浏览代码

fix test policy

/develop/action-slice
Andrew Cohen 4 年前
当前提交
3f7d68b8
共有 1 个文件被更改,包括 1 次插入3 次删除
  1. 4
      ml-agents/mlagents/trainers/tests/torch/test_policy.py

4
ml-agents/mlagents/trainers/tests/torch/test_policy.py


if len(memories) > 0:
memories = torch.stack(memories).unsqueeze(0)
log_probs, entropy, values = policy.evaluate_actions(
log_probs, entropy = policy.evaluate_actions(
tensor_obs,
masks=act_masks,
actions=agent_action,

assert log_probs.flatten().shape == (64, _size)
assert entropy.shape == (64,)
for val in values.values():
assert val.shape == (64,)
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])

正在加载...
取消
保存