浏览代码

Fix gym tests

/MLA-1734-demo-provider
Arthur Juliani 3 年前
当前提交
91a93815
共有 2 个文件被更改,包括 11 次插入11 次删除
  1. 16
      gym-unity/gym_unity/envs/__init__.py
  2. 6
      gym-unity/gym_unity/tests/test_gym.py

16
gym-unity/gym_unity/envs/__init__.py


def _get_n_vis_obs(self) -> int:
result = 0
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 3:
for obs_spec in self.group_spec.observation_specs:
if len(obs_spec.shape) == 3:
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 3:
result.append(sen_spec.shape)
for obs_spec in self.group_spec.observation_specs:
if len(obs_spec.shape) == 3:
result.append(obs_spec.shape)
return result
def _get_vis_obs_list(

def _get_vec_obs_size(self) -> int:
result = 0
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 1:
result += sen_spec.shape[0]
for obs_spec in self.group_spec.observation_specs:
if len(obs_spec.shape) == 1:
result += obs_spec.shape[0]
return result
def render(self, mode="rgb_array"):

6
gym-unity/gym_unity/tests/test_gym.py


TerminalSteps,
BehaviorMapping,
)
from mlagents.trainers.tests.dummy_config import create_sensor_specs_with_shapes
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes
def test_gym_wrapper():

obs_shapes = [(vector_observation_space_size,)]
for _ in range(number_visual_observations):
obs_shapes += [(8, 8, 3)]
sen_spec = create_sensor_specs_with_shapes(obs_shapes)
return BehaviorSpec(sen_spec, action_spec)
obs_spec = create_observation_specs_with_shapes(obs_shapes)
return BehaviorSpec(obs_spec, action_spec)
def create_mock_vector_steps(specs, num_agents=1, number_visual_observations=0):

正在加载...
取消
保存