浏览代码

Fix test and replace range with arange

/develop/add-fire/categoricaldist
Ervin Teng 4 年前
当前提交
6b29a4c9
共有 2 个文件被更改,包括 6 次插入4 次删除
  1. 4
      ml-agents/mlagents/trainers/tests/torch/test_distributions.py
  2. 6
      ml-agents/mlagents/trainers/torch/distributions.py

4
ml-agents/mlagents/trainers/tests/torch/test_distributions.py


torch.manual_seed(0)
act_size = 4
test_prob = torch.tensor(
[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)
[[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)]
assert action.shape == (1,)
assert action.shape == (1, 1)
assert action < act_size
# Make sure the first action as higher probability than the others.

6
ml-agents/mlagents/trainers/torch/distributions.py


return torch.multinomial(self.probs, 1)
def pdf(self, value):
idx = torch.range(end=len(value)).unsqueeze(-1)
return torch.gather(self.probs.permute(1, 0)[value.flatten().long()], -1, idx).squeeze(-1)
idx = torch.arange(start=0, end=len(value)).unsqueeze(-1)
return torch.gather(
self.probs.permute(1, 0)[value.flatten().long()], -1, idx
).squeeze(-1)
def log_prob(self, value):
return torch.log(self.pdf(value))

正在加载...
取消
保存