
[bug-fix] Fix issue where NaNs are outputted by the policy when training Match3 (#4664)

* match3 settings

* Add epsilon to log

* Add another epsilon

* Revert match3 configs

* NaN-free masking method

* Add comment for paper

* Add comment for paper

Co-authored-by: Chris Elion <chris.elion@unity3d.com>
Branch: MLA-1734-demo-provider
GitHub · 4 years ago
Current commit: de27d7a6
1 file changed, 10 insertions(+), 7 deletions(-)

ml-agents/mlagents/trainers/torch/distributions.py (17 changed lines)

         ).squeeze(-1)

     def log_prob(self, value):
-        return torch.log(self.pdf(value))
+        return torch.log(self.pdf(value) + EPSILON)

-        return torch.log(self.probs)
+        return torch.log(self.probs + EPSILON)

-        return -torch.sum(self.probs * torch.log(self.probs), dim=-1)
+        return -torch.sum(self.probs * torch.log(self.probs + EPSILON), dim=-1)

 class GaussianDistribution(nn.Module):
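For context, each of these changes guards a `torch.log` call against a zero input. A minimal standalone sketch of the pattern follows; the `EPSILON` value here is an assumption for illustration (the real constant is defined elsewhere in distributions.py):

```python
import torch

EPSILON = 1e-7  # assumed value for illustration; the actual constant lives in distributions.py

probs = torch.tensor([0.0, 0.25, 0.75])  # a zero probability can occur once actions are masked

print(torch.log(probs))            # tensor([  -inf, -1.3863, -0.2877])
print(torch.log(probs + EPSILON))  # finite everywhere

# Entropy before the fix: 0 * log(0) = 0 * -inf evaluates to NaN and poisons the loss.
print(-torch.sum(probs * torch.log(probs), dim=-1))            # tensor(nan)
# Entropy after the fix: finite.
print(-torch.sum(probs * torch.log(probs + EPSILON), dim=-1))  # ~0.5623
```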

         return nn.ModuleList(branches)

     def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-        raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
-        normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1)
-        normalized_logits = torch.log(normalized_probs + EPSILON)
-        return normalized_logits
+        # Zero out masked logits, then subtract a large value. Technique mentioned here:
+        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barracuda-friendly.
+        flipped_mask = 1.0 - mask
+        adj_logits = logits * mask - 1e8 * flipped_mask
+        probs = torch.nn.functional.softmax(adj_logits, dim=-1)
+        log_probs = torch.log(probs + EPSILON)
+        return log_probs

     def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
         split_masks = []
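To illustrate why the new masking path is NaN-free, here is a small, self-contained comparison of the old and new `_mask_branch` logic. The logits, mask, and `EPSILON` value are made up for the example; only the two masking formulas are taken from the diff above.

```python
import torch

EPSILON = 1e-7  # assumed small constant, playing the same role as in distributions.py

# A masked-out action whose raw logit has grown very large: after softmax,
# every *valid* action's probability underflows to exactly 0 in float32.
logits = torch.tensor([[0.0, 200.0, 0.0]])
mask = torch.tensor([[1.0, 0.0, 1.0]])  # action 1 is invalid

# Old approach: softmax first, mask, then renormalize.
# The normalizer is 0 here, so 0 / 0 produces NaN.
raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1)
print(torch.log(normalized_probs + EPSILON))  # tensor([[nan, nan, nan]])

# New approach: push masked logits far below the valid ones *before* the softmax,
# so masked actions get ~0 probability and no division by a data-dependent
# (possibly zero) sum is ever performed.
flipped_mask = 1.0 - mask
adj_logits = logits * mask - 1e8 * flipped_mask
probs = torch.nn.functional.softmax(adj_logits, dim=-1)
print(torch.log(probs + EPSILON))  # finite log-probabilities for the valid actions
```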
