
Fix masking for torch

/MLA-1734-demo-provider
vincentpierre · 4 years ago
Current commit
90da7426
2 files changed, 8 insertions(+), 6 deletions(-)
  1. Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (4 changes)
  2. ml-agents/mlagents/trainers/torch/distributions.py (10 changes)

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (4 changes)


      if (maskActions)
      {
          // Prevents the agent from picking an action that would make it collide with a wall
-         var positionX = (int)transform.position.x;
-         var positionZ = (int)transform.position.z;
+         var positionX = (int)transform.localPosition.x;
+         var positionZ = (int)transform.localPosition.z;
          var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;
          if (positionX == 0)
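The underlying issue: transform.position is a world-space coordinate, so the edge checks (positionX == 0, positionX == maxPosition) only hold when the GridArea sits at the world origin; transform.localPosition is measured relative to the agent's parent area, so the checks stay correct wherever the area is placed. Below is a hypothetical Python sketch of the same wall-collision mask, for illustration only (the function name and action ids are assumptions, not the project's values):

    # Hypothetical sketch of the wall-collision mask; action ids
    # (1=right, 2=left, 3=up, 4=down) are assumed for illustration.
    def disallowed_actions(position_x: int, position_z: int, max_position: int) -> set:
        disallowed = set()
        if position_x == 0:
            disallowed.add(2)  # at the left wall: moving left would collide
        if position_x == max_position:
            disallowed.add(1)  # at the right wall
        if position_z == 0:
            disallowed.add(4)  # at the bottom wall
        if position_z == max_position:
            disallowed.add(3)  # at the top wall
        return disallowed

    # With local coordinates an agent at the area's own corner (0, 0) masks
    # "left" and "down" even when the GridArea is far from the world origin.
    assert disallowed_actions(0, 0, 4) == {2, 4}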

ml-agents/mlagents/trainers/torch/distributions.py (10 changes)


             branches.append(branch_output_layer)
         return nn.ModuleList(branches)

-    def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def _mask_branch(
+        self, logits: torch.Tensor, allow_mask: torch.Tensor
+    ) -> torch.Tensor:
-        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barrcuda-friendly.
-        flipped_mask = 1.0 - mask
-        adj_logits = logits * mask - 1e8 * flipped_mask
+        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barracuda-friendly.
+        mask = -1.0 * allow_mask + 1.0
+        adj_logits = logits * allow_mask - 1e8 * mask
         probs = torch.nn.functional.softmax(adj_logits, dim=-1)
         log_probs = torch.log(probs + EPSILON)
         return log_probs
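The rename from mask to allow_mask makes the convention explicit: 1.0 marks an allowed action, 0.0 a disallowed one, and the blocking mask is derived from it. Masking is done with arithmetic only (multiply, subtract a large constant) rather than boolean indexing or torch.where, which is what keeps the graph ONNX- and Barracuda-friendly as the comment notes. A minimal standalone sketch of the same masked softmax (the EPSILON value and the example tensors are assumptions for illustration):

    import torch

    EPSILON = 1e-7  # assumption: small constant that keeps log() finite

    def mask_branch(logits: torch.Tensor, allow_mask: torch.Tensor) -> torch.Tensor:
        # allow_mask holds 1.0 for allowed actions, 0.0 for disallowed ones.
        block_mask = -1.0 * allow_mask + 1.0
        # Disallowed logits are pushed to -1e8, so softmax gives them ~0 probability.
        adj_logits = logits * allow_mask - 1e8 * block_mask
        probs = torch.nn.functional.softmax(adj_logits, dim=-1)
        return torch.log(probs + EPSILON)

    logits = torch.tensor([[2.0, 0.5, -1.0]])
    allow = torch.tensor([[1.0, 0.0, 1.0]])  # the middle action is masked out
    print(mask_branch(logits, allow).exp())
    # ~tensor([[0.9526, 0.0000, 0.0474]]): the masked action keeps only the EPSILON floor.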
