浏览代码

fixing tests

/layernorm
vincentpierre 4 年前
当前提交
3fe74831
共有 2 个文件被更改,包括 5 次插入5 次删除
  1. 8
      ml-agents-envs/mlagents_envs/tests/test_envs.py
  2. 2
      ml-agents-envs/mlagents_envs/tests/test_steps.py

8
ml-agents-envs/mlagents_envs/tests/test_envs.py


assert len(spec.sensor_specs) == len(decision_steps.obs)
assert len(spec.sensor_specs) == len(terminal_steps.obs)
n_agents = len(decision_steps)
for spec, obs in zip(spec.sensor_specs, decision_steps.obs):
assert (n_agents,) + spec.shape == obs.shape
for sen_spec, obs in zip(spec.sensor_specs, decision_steps.obs):
assert (n_agents,) + sen_spec.shape == obs.shape
for spec, obs in zip(spec.sensor_specs.shapes, terminal_steps.obs):
assert (n_agents,) + spec.shape == obs.shape
for sen_spec, obs in zip(spec.sensor_specs, terminal_steps.obs):
assert (n_agents,) + sen_spec.shape == obs.shape
@mock.patch("mlagents_envs.env_utils.launch_executable")

2
ml-agents-envs/mlagents_envs/tests/test_steps.py


def test_empty_decision_steps():
specs = BehaviorSpec(
sensor_specs=create_sensor_specs_with_shapes([(3, 2), (5,)]),
action_specs=ActionSpec.create_continuous(3),
action_spec=ActionSpec.create_continuous(3),
)
ds = DecisionSteps.empty(specs)
assert len(ds.obs) == 2

正在加载...
取消
保存