|
|
|
|
|
|
from mlagents.trainers.exception import UnityTrainerException |
|
|
|
|
|
|
|
|
|
|
|
def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]: |
|
|
|
def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]: |
|
|
|
""" |
|
|
|
Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was |
|
|
|
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention |
|
|
|
|
|
|
|
|
|
|
if exporting_to_onnx.is_exporting(): |
|
|
|
# When exporting to ONNX, we want to transpose the entities. This is |
|
|
|
# because ONNX only support input in NCHW (channel first) format. |
|
|
|
# Barracuda also expect to get data in NCHW. |
|
|
|
entities = [ |
|
|
|
torch.transpose(obs, 2, 1).reshape( |
|
|
|
-1, int(obs.shape[1]), int(obs.shape[2]) |
|
|
|
) |
|
|
|
for obs in entities |
|
|
|
] |
|
|
|
|
|
|
|
(torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations |
|
|
|
(torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in entities |
|
|
|
] |
|
|
|
return key_masks |
|
|
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: |
|
|
|
num_entities = self.entity_num_max_elements |
|
|
|
if num_entities < 0: |
|
|
|
if exporting_to_onnx.is_exporting(): |
|
|
|
raise UnityTrainerException( |
|
|
|
"Trying to export an attention mechanism that doesn't have a set max \ |
|
|
|
number of elements." |
|
|
|
) |
|
|
|
num_entities = entities.shape[1] |
|
|
|
|
|
|
|
if exporting_to_onnx.is_exporting(): |
|
|
|
# When exporting to ONNX, we want to transpose the entities. This is |
|
|
|
# because ONNX only support input in NCHW (channel first) format. |
|
|
|
# Barracuda also expect to get data in NCHW. |
|
|
|
entities = torch.transpose(entities, 2, 1).reshape( |
|
|
|
-1, num_entities, self.entity_size |
|
|
|
) |
|
|
|
|
|
|
|
num_entities = self.entity_num_max_elements |
|
|
|
if num_entities < 0: |
|
|
|
if exporting_to_onnx.is_exporting(): |
|
|
|
raise UnityTrainerException( |
|
|
|
"Trying to export an attention mechanism that doesn't have a set max \ |
|
|
|
number of elements." |
|
|
|
) |
|
|
|
num_entities = entities.shape[1] |
|
|
|
expanded_self = x_self.reshape(-1, 1, self.self_size) |
|
|
|
expanded_self = torch.cat([expanded_self] * num_entities, dim=1) |
|
|
|
# Concatenate all observations with self |
|
|
|