浏览代码

Changing model export to be compatible with Barracuda

/bullet-hell-barracuda-test-1.3.1
vincentpierre 4 年前
当前提交
8baaaf4d
共有 1 个文件被更改,包括 31 次插入10 次删除
  1. 41
      ml-agents/mlagents/trainers/torch/attention.py

41
ml-agents/mlagents/trainers/torch/attention.py


from mlagents.trainers.exception import UnityTrainerException
def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]:
def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention

if exporting_to_onnx.is_exporting():
# When exporting to ONNX, we want to transpose the entities. This is
# because ONNX only support input in NCHW (channel first) format.
# Barracuda also expect to get data in NCHW.
entities = [
torch.transpose(obs, 2, 1).reshape(
-1, int(obs.shape[1]), int(obs.shape[2])
)
for obs in entities
]
(torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
(torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in entities
]
return key_masks

)
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
num_entities = self.entity_num_max_elements
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = entities.shape[1]
if exporting_to_onnx.is_exporting():
# When exporting to ONNX, we want to transpose the entities. This is
# because ONNX only support input in NCHW (channel first) format.
# Barracuda also expect to get data in NCHW.
entities = torch.transpose(entities, 2, 1).reshape(
-1, num_entities, self.entity_size
)
num_entities = self.entity_num_max_elements
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = entities.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
# Concatenate all observations with self

正在加载...
取消
保存