from mlagents.torch_utils import torch
from typing import Tuple, Optional, List
from mlagents.trainers.torch.layers import (
    LinearEncoder,
    Initialization,
    linear_layer,
    LayerNorm,
)
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
from mlagents.trainers.exception import UnityTrainerException


# If not concatenating self, the input to each entity encoder is just the entity size
if not concat_self:
    self.self_size = 0
# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
self.ent_encoders = torch.nn.ModuleList(
    [
        LinearEncoder(
            self.self_size + ent_size,
            1,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
        for ent_size in self.entity_sizes
    ]
)
self.embedding_norm = LayerNorm()
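
# A minimal sketch (not part of this file) of what the scaled-normal initialization
# above is assumed to amount to for a single linear layer; `input_size` is a
# hypothetical stand-in for self.self_size + ent_size:
#
#     layer = torch.nn.Linear(input_size, embedding_size)
#     torch.nn.init.normal_(layer.weight)                    # unit-variance normal init
#     layer.weight.data *= (0.125 / embedding_size) ** 0.5   # T-Fixup-style gain
#     torch.nn.init.zeros_(layer.bias)
#
# Shrinking the initial weights this way keeps the residual stream well scaled
# early in training, which is the motivation in the T-Fixup paper linked above.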


def forward(
    self, x_self: torch.Tensor, entities: List[torch.Tensor]
) -> torch.Tensor:
    # self_and_ent holds one tensor per entity group, optionally concatenated
    # with x_self when concat_self is set; its construction is omitted in this excerpt.
    ...
    # Encode each entity group and concatenate along the entity dimension
    encoded_entities = torch.cat(
        [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
        dim=1,
    )
    encoded_entities = self.embedding_norm(encoded_entities)
    return encoded_entities
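
# Rough shape walk-through (an assumption, for illustration only): if x_self is
# [batch, self_size] and each tensor in `entities` is [batch, n_i, ent_size_i],
# then each ent_encoder maps its group to [batch, n_i, embedding_size], the
# torch.cat along dim=1 yields [batch, sum(n_i), embedding_size], and LayerNorm
# normalizes each embedded entity before the result is returned.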


@staticmethod
# (the decorated method is omitted in this excerpt)


self.attention = MultiHeadAttention(
    num_heads=num_heads, embedding_size=embedding_size
)
# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
self.fc_q = linear_layer(
    embedding_size,
    embedding_size,
    kernel_init=Initialization.Normal,
    kernel_gain=(0.125 / embedding_size) ** 0.5,
)

self.residual_norm = LayerNorm()
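
# In the forward pass below, query/key/value are presumably the projections of
# the embedded entities, e.g. roughly:
#
#     query = self.fc_q(inp)   # [batch, num_entities, embedding_size]
#     key = self.fc_k(inp)
#     value = self.fc_v(inp)
#
# fc_k and fc_v are not shown in this excerpt; their names here are assumptions.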


def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
    # Gather the maximum number of entities and build the attention mask
    ...
    output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
    # Residual
    output = self.fc_out(output) + inp
    output = self.residual_norm(output)
    ...
    # Residual between x_self and the output of the module
    return output
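
# For orientation, a minimal generic sketch of the same residual-attention pattern
# using torch.nn.MultiheadAttention (an illustration with assumed shapes, not the
# MultiHeadAttention class used above; embedding_size and num_heads assumed in scope):
#
#     attn = torch.nn.MultiheadAttention(embedding_size, num_heads, batch_first=True)
#     norm = torch.nn.LayerNorm(embedding_size)
#
#     def residual_block(x, key_padding_mask):
#         # x: [batch, num_entities, embedding_size]; True in the mask marks padding
#         attended, _ = attn(x, x, x, key_padding_mask=key_padding_mask)
#         return norm(attended + x)  # residual add, then normalize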