
Added a comment and included the change of #4715 for simplicity

/MLA-1734-demo-provider
vincentpierre committed 4 years ago
Current commit: e85d8e35
1 file changed, 5 insertions and 5 deletions
ml-agents/mlagents/trainers/torch/distributions.py (10 changed lines: 5 additions, 5 deletions)

     ) -> torch.Tensor:
         # Zero out masked logits, then subtract a large value. Technique mentioned here:
         # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barracuda-friendly.
-        mask = -1.0 * allow_mask + 1.0
-        adj_logits = logits * allow_mask - 1e8 * mask
-        probs = torch.nn.functional.softmax(adj_logits, dim=-1)
-        log_probs = torch.log(probs + EPSILON)
-        return log_probs
+        block_mask = -1.0 * allow_mask + 1.0
+        # We do -1 * tensor + constant instead of constant - tensor because it seems
+        # Barracuda might swap the inputs of a "Sub" operation
+        logits = logits * allow_mask - 1e8 * block_mask
+        return logits
     def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
         split_masks = []
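For context, the masking trick the comment cites (invalid action masking, https://arxiv.org/abs/2006.14171) is easy to demonstrate standalone. Below is a minimal PyTorch sketch; the free-standing mask_branch function and the example tensors are assumptions for illustration, not the repository's exact API (in distributions.py this logic is a method of the categorical distribution class):

    import torch

    def mask_branch(logits: torch.Tensor, allow_mask: torch.Tensor) -> torch.Tensor:
        # Complement of the allow mask: 1.0 where an action is blocked, 0.0 where allowed.
        # Written as -1 * tensor + constant rather than constant - tensor, since
        # Barracuda might swap the inputs of a "Sub" operation.
        block_mask = -1.0 * allow_mask + 1.0
        # Zero out blocked logits, then push them to a large negative value so that
        # softmax assigns them near-zero probability.
        return logits * allow_mask - 1e8 * block_mask

    logits = torch.tensor([[2.0, 1.0, 0.5]])
    allow_mask = torch.tensor([[1.0, 0.0, 1.0]])  # action 1 is disallowed

    probs = torch.softmax(mask_branch(logits, allow_mask), dim=-1)
    print(probs)  # ~[[0.82, 0.00, 0.18]]; the blocked action gets zero probability

Note the behavioral change folded in from #4715: the removed lines computed log-probabilities inline (softmax, then log with an EPSILON guard), while the new code returns the masked logits directly, so a caller that needs log-probabilities applies torch.log_softmax itself, which is numerically more stable than log(softmax(x) + EPSILON).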
