浏览代码

Added comment

/develop/add-fire/categoricaldist
Ervin Teng 4 年前
当前提交
6aa6c931
共有 1 个文件被更改,包括 2 次插入0 次删除
  1. 2
      ml-agents/mlagents/trainers/torch/distributions.py

2
ml-agents/mlagents/trainers/torch/distributions.py


return torch.multinomial(self.probs, 1)
def pdf(self, value):
# This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]),
# but torch.diag is not supported by ONNX export.
idx = torch.arange(start=0, end=len(value)).unsqueeze(-1)
return torch.gather(
self.probs.permute(1, 0)[value.flatten().long()], -1, idx

正在加载...
取消
保存