浏览代码

fix torch utils test

/develop/action-spec-gym
Andrew Cohen 4 年前
当前提交
230497f5
共有 1 个文件被更改,包括 9 次插入6 次删除
  1. 15
      ml-agents/mlagents/trainers/tests/torch/test_utils.py

15
ml-agents/mlagents/trainers/tests/torch/test_utils.py


log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
action_list, dist_list
)
assert log_probs.shape == (1, 2, 2)
for lp in log_probs:
assert lp.shape == (1, 2)
assert all_probs is None
assert all_probs == []
for log_prob in log_probs.flatten():
for log_prob in log_probs:
assert log_prob == pytest.approx(-0.919, abs=0.01)
for lp in log_prob.flatten():
assert lp == pytest.approx(-0.919, abs=0.01)
for ent in entropies.flatten():
# entropy of standard normal at 0

log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy(
action_list, dist_list
)
assert all_probs.shape == (1, len(dist_list * act_size))
for all_prob in all_probs:
assert all_prob.shape == (1, act_size)
assert log_probs.flatten()[0] > log_probs.flatten()[1]
assert log_probs[0] > log_probs[1]
def test_masked_mean():

正在加载...
取消
保存