
Merge pull request #4714 from Unity-Technologies/develop-fix-mask

Fix masking for torch
GitHub · 4 years ago
Commit 2c744129
2 files changed, 11 insertions(+), 9 deletions(-)
  1. Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (4 changes)
  2. ml-agents/mlagents/trainers/torch/distributions.py (16 changes)

Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (4 changes)
 if (maskActions)
 {
     // Prevents the agent from picking an action that would make it collide with a wall
-    var positionX = (int)transform.position.x;
-    var positionZ = (int)transform.position.z;
+    var positionX = (int)transform.localPosition.x;
+    var positionZ = (int)transform.localPosition.z;
     var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;
     if (positionX == 0)
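Note on this half of the fix: switching from transform.position to transform.localPosition makes the wall checks use coordinates relative to the agent's parent training area rather than world space. With world coordinates, the 0 to gridSize - 1 bounds comparison only lines up when the grid area sits at the world origin, so action masking would have been wrong for offset or duplicated areas; that appears to be the motivation for the change.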

ml-agents/mlagents/trainers/torch/distributions.py (16 changes)
             branches.append(branch_output_layer)
         return nn.ModuleList(branches)

-    def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barracuda-friendly.
-        flipped_mask = 1.0 - mask
-        adj_logits = logits * mask - 1e8 * flipped_mask
-        probs = torch.nn.functional.softmax(adj_logits, dim=-1)
-        log_probs = torch.log(probs + EPSILON)
-        return log_probs
+    def _mask_branch(
+        self, logits: torch.Tensor, allow_mask: torch.Tensor
+    ) -> torch.Tensor:
+        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barracuda-friendly.
+        block_mask = -1.0 * allow_mask + 1.0
+        # We do -1 * tensor + constant instead of constant - tensor because it seems
+        # Barracuda might swap the inputs of a "Sub" operation
+        logits = logits * allow_mask - 1e8 * block_mask
+        return logits

     def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
         split_masks = []
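
The behavioral change here is visible in the diff itself: _mask_branch now returns adjusted logits instead of log-probabilities (the old version applied a softmax and a log before returning), leaving normalization to the caller, and the block mask is computed as -1.0 * allow_mask + 1.0 to sidestep the Barracuda "Sub" input-order issue noted in the new comment. Below is a minimal, self-contained sketch of the masking technique; the free-standing mask_branch function and the example tensors are illustrative, not the repo's code, and it assumes a {0, 1} allow_mask where 1 marks a valid action.

import torch

def mask_branch(logits: torch.Tensor, allow_mask: torch.Tensor) -> torch.Tensor:
    # allow_mask is 1.0 for valid actions, 0.0 for invalid ones.
    # block_mask flips it; written as -1 * tensor + constant, matching the
    # diff's Barracuda-friendly formulation of (1 - allow_mask).
    block_mask = -1.0 * allow_mask + 1.0
    # Zero out blocked logits, then push them to a huge negative value so
    # softmax assigns them effectively zero probability.
    return logits * allow_mask - 1e8 * block_mask

logits = torch.tensor([[2.0, 1.0, 0.5, -0.5]])
allow_mask = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # actions 1 and 3 are masked out
probs = torch.softmax(mask_branch(logits, allow_mask), dim=-1)
print(probs)  # masked entries come out as ~0.0; valid entries renormalize

After the softmax, the two masked actions receive vanishingly small probability while the valid ones renormalize among themselves, which is exactly what the large negative offset is for.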
