浏览代码

fix torch test policy

/develop/actionmodel-csharp
Andrew Cohen 4 年前
当前提交
e88558c3
共有 1 个文件被更改,包括 2 次插入8 次删除
  1. 10
      ml-agents/mlagents/trainers/tests/torch/test_policy.py

10
ml-agents/mlagents/trainers/tests/torch/test_policy.py


run_out = policy.evaluate(decision_step, list(decision_step.agent_id))
if discrete:
run_out["action"]["discrete_action"].shape == (
NUM_AGENTS,
len(DISCRETE_ACTION_SPACE),
)
run_out["action"].discrete.shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))
assert run_out["action"]["continuous_action"].shape == (
NUM_AGENTS,
VECTOR_ACTION_SPACE,
)
assert run_out["action"].continuous.shape == (NUM_AGENTS, VECTOR_ACTION_SPACE)
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])

正在加载...
取消
保存