|
|
|
|
|
|
from mlagents.torch_utils import torch |
|
|
|
from typing import Tuple, Optional, List |
|
|
|
from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer |
|
|
|
from mlagents.trainers.torch.layers import ( |
|
|
|
LinearEncoder, |
|
|
|
Initialization, |
|
|
|
linear_layer, |
|
|
|
LayerNorm, |
|
|
|
) |
|
|
|
from mlagents.trainers.torch.model_serialization import exporting_to_onnx |
|
|
|
from mlagents.trainers.exception import UnityTrainerException |
|
|
|
|
|
|
|
|
|
|
# NOTE(review): fragment — the head of this constructor and the opening of the
# container (presumably a ModuleList of per-entity encoders) that this
# comprehension tail closes are missing from this view — TODO recover from VCS.
# Presumably one encoder is built per entity observation size — confirm.
for ent_size in self.entity_sizes



]



)



# Layer normalization applied to the entity embeddings produced in forward().
self.embedding_norm = LayerNorm()
|
|
|
|
|
|
|
# NOTE(review): fragment — lines are missing from this view: the signature's
# closing (presumably `) -> torch.Tensor:`) and the first half of the body,
# which presumably builds `self_and_ent` from `x_self`/`entities` and opens a
# `torch.cat(` call whose argument list resumes below — TODO recover from VCS.
def forward(



self, x_self: torch.Tensor, entities: List[torch.Tensor]






# Encode each entity tensor with its matching encoder, then concatenate the
# results along dim=1 (presumably the entity axis) — TODO confirm axis meaning.
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],



dim=1,



)



# Normalize the concatenated entity embeddings before returning them.
encoded_entities = self.embedding_norm(encoded_entities)



return encoded_entities
|
|
|
|
|
|
|
# NOTE(review): fragments — the method this decorator belongs to is missing
# from this view, as is the call head (presumably `self.fc_out =
# linear_layer(...)`) that the keyword arguments below close — TODO recover
# from VCS.
@staticmethod






# Normal-distribution weight init; the gain shrinks as embedding_size grows,
# presumably to keep the residual branch near-identity early in training —
# TODO confirm intent.
kernel_init=Initialization.Normal,



kernel_gain=(0.125 / embedding_size) ** 0.5,



)



# Layer normalization applied after the residual addition in forward().
self.residual_norm = LayerNorm()
|
|
|
|
|
|
|
# NOTE(review): fragment — the lines between the "Gather" comment and the
# attention call are missing from this view (presumably the computation of
# `query`, `key`, `value`, `num_ent` and `mask` from `inp` and `key_masks`)
# — TODO recover from VCS.
def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:



# Gather the maximum number of entities information






# Self-attention over the entities; `mask` presumably hides padded/invalid
# entries derived from `key_masks` — TODO confirm mask polarity.
output, _ = self.attention(query, key, value, num_ent, num_ent, mask)



# Residual



# Project the attention output back and add the skip connection from `inp`.
output = self.fc_out(output) + inp



output = self.residual_norm(output)



# Residual between x_self and the output of the module
# NOTE(review): the comment above mentions `x_self`, but only `inp` is in
# scope in this method — comment may be stale; verify against history.
return output