浏览代码

Fix and test for masked_mean

/develop/add-fire/sac-lst
Ervin Teng 4 年前
当前提交
a88d3581
共有 2 个文件被更改,包括 17 次插入1 次删除
  1. 16
      ml-agents/mlagents/trainers/tests/torch/test_utils.py
  2. 2
      ml-agents/mlagents/trainers/torch/utils.py

16
ml-agents/mlagents/trainers/tests/torch/test_utils.py


assert entropies.shape == (1, len(dist_list))
# Make sure the first action has high probability than the others.
assert log_probs.flatten()[0] > log_probs.flatten()[1]
def test_masked_mean():
test_input = torch.tensor([1, 2, 3, 4, 5])
masks = torch.ones_like(test_input).bool()
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 3.0
masks = torch.tensor([False, False, True, True, True])
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 4.0
# Make sure it works if all masks are off
masks = torch.tensor([False, False, False, False, False])
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 0.0

2
ml-agents/mlagents/trainers/torch/utils.py


:param tensor: Tensor which needs mean computation.
:param masks: Boolean tensor of masks with same dimension as tensor.
"""
return (tensor * masks).sum() / masks.float().sum()
return (tensor * masks).sum() / torch.clamp(masks.float().sum(), min=1.0)
正在加载...
取消
保存