浏览代码

Add another epsilon

/develop/fix-nan
Ervin Teng 4 年前
当前提交
eb4f3065
共有 1 个文件被更改,包括 3 次插入1 次删除
  1. 4
      ml-agents/mlagents/trainers/torch/distributions.py

4
ml-agents/mlagents/trainers/torch/distributions.py


def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1)
normalized_probs = raw_probs / (
torch.sum(raw_probs, dim=-1).unsqueeze(-1) + EPSILON
)
normalized_logits = torch.log(normalized_probs + EPSILON)
return normalized_logits

正在加载...
取消
保存