浏览代码

Fix util test

/develop/add-fire/categoricaldist
Ervin Teng 4 年前
当前提交
5bf72236
共有 1 个文件被更改,包括 3 次插入3 次删除
  1. 6
      ml-agents/mlagents/trainers/tests/torch/test_utils.py

6
ml-agents/mlagents/trainers/tests/torch/test_utils.py


# Add two dists to the list.
act_size = 2
test_prob = torch.tensor(
[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)
[[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)]
) # High prob for first action
dist_list = [CategoricalDistInstance(test_prob), CategoricalDistInstance(test_prob)]
action_list = [torch.tensor([0]), torch.tensor([1])]

assert all_probs.shape == (len(dist_list * act_size),)
assert entropies.shape == (len(dist_list),)
assert all_probs.shape == (1, len(dist_list * act_size))
assert entropies.shape == (1, len(dist_list))
# Make sure the first action has high probability than the others.
assert log_probs.flatten()[0] > log_probs.flatten()[1]
正在加载...
取消
保存