浏览代码

Fix discrete export (#4322)

Fix discrete export
/develop/add-fire
GitHub 4 年前
当前提交
dba529ff
共有 1 个文件被更改,包括 2 次插入1 次删除
  1. 3
      ml-agents/mlagents/trainers/torch/distributions.py

3
ml-agents/mlagents/trainers/torch/distributions.py


return torch.multinomial(self.probs, 1)
def pdf(self, value):
return torch.diag(self.probs.T[value.flatten().long()])
idx = torch.range(end=len(value)).unsqueeze(-1)
return torch.gather(self.probs.permute(1, 0)[value.flatten().long()], -1, idx).squeeze(-1)
def log_prob(self, value):
return torch.log(self.pdf(value))

正在加载...
取消
保存