|
|
|
|
|
|
        return value_attention, att
|
|
class EntityEmbedding(torch.nn.Module):
    """
    Embeds a batch of entity observations of a single type before they are passed to
    a self-attention block.
    """

    def __init__(
        self,
        entity_size: int,
        entity_num_max_elements: Optional[int],
        embedding_size: int,
    ):
        """
        Constructs an EntityEmbedding module.
        :param entity_size: Size of the other entity.
        :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
        :param embedding_size: Embedding size for the entity encoder.
        """
        super().__init__()
        # The "self" encoding size defaults to 0 (no self concatenation in forward)
        self.self_size: int = 0
        self.entity_size: int = entity_size
        self.entity_num_max_elements: int = -1
        if entity_num_max_elements is not None:
            self.entity_num_max_elements = entity_num_max_elements
        # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
        self.ent_encoder = LinearEncoder(
            self.self_size + self.entity_size,
            1,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )

    def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
        if self.self_size > 0:
            num_entities = self.entity_num_max_elements
            if num_entities < 0:
                if exporting_to_onnx.is_exporting():
                    raise UnityTrainerException(
                        "Trying to export an attention mechanism that doesn't have a set max \
                        number of elements."
                    )
                num_entities = entities.shape[1]
            # Tile the self encoding and concatenate it to every entity
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            entities = torch.cat([expanded_self, entities], dim=2)
        # Encode entities
        encoded_entities = self.ent_encoder(entities)
        return encoded_entities


class EntityEmbeddings(torch.nn.Module):
    """
    Embeds a list of entity observations, one encoder per entity type, optionally
    concatenating a "self" encoding to every entity before encoding it.
    """

    def __init__(
        self,
        x_self_size: int,
        entity_sizes: List[int],
        entity_num_max_elements: Optional[List[int]],
        embedding_size: int,
        concat_self: bool = True,
    ):
        """
        Constructs an EntityEmbeddings module.
        :param x_self_size: Size of the "self" encoding.
        :param entity_sizes: List of sizes for other entities. Should be of length
            equivalent to the number of entities.
        :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
        :param embedding_size: Embedding size for entity encoders.
        :param concat_self: Whether to concatenate the self encoding to each entity.
        """
        super().__init__()
        self.self_size: int = x_self_size
        self.entity_sizes: List[int] = entity_sizes
        self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
        if entity_num_max_elements is not None:
            self.entity_num_max_elements = entity_num_max_elements
        self.concat_self: bool = concat_self
        # If not concatenating self, the encoders only see the raw entity observations
        if not concat_self:
            self.self_size = 0
        # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
        self.ent_encoders = torch.nn.ModuleList(
            [
                LinearEncoder(
                    self.self_size + ent_size,
                    1,
                    embedding_size,
                    kernel_init=Initialization.Normal,
                    kernel_gain=(0.125 / embedding_size) ** 0.5,
                )
                for ent_size in self.entity_sizes
            ]
        )
        self.embedding_norm = LayerNorm()

    def forward(
        self, x_self: torch.Tensor, entities: List[torch.Tensor]
    ) -> torch.Tensor:
        if self.concat_self:
            # Concatenate every entity observation with the (tiled) self encoding
            self_and_ent: List[torch.Tensor] = []
            for num_entities, ent in zip(self.entity_num_max_elements, entities):
                if num_entities < 0:
                    if exporting_to_onnx.is_exporting():
                        raise UnityTrainerException(
                            "Trying to export an attention mechanism that doesn't have a set max \
                            number of elements."
                        )
                    num_entities = ent.shape[1]
                expanded_self = x_self.reshape(-1, 1, self.self_size)
                expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
                self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
        else:
            self_and_ent = entities
        # Encode and concatenate entities
        encoded_entities = torch.cat(
            [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
            dim=1,
        )
        encoded_entities = self.embedding_norm(encoded_entities)
        return encoded_entities
|
    @staticmethod
    def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Takes a List of Tensors and returns a List of mask Tensors, with 1 if the input
        was all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
        layer to mask the padding observations.
        """
        with torch.no_grad():
            # Generate the masking tensors for each entities tensor (mask only if all zeros)
            key_masks: List[torch.Tensor] = [
                (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
                for ent in observations
            ]
        return key_masks
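
# Illustrative usage sketch (not part of the library; names and sizes below are
# assumptions): embed a zero-padded batch of entity observations and build the
# matching padding masks.
#
#     x_self = torch.ones((batch_size, 6))           # "self" encoding of size 6
#     enemies = torch.ones((batch_size, 5, 4))       # up to 5 entities of size 4, zero-padded
#     embedder = EntityEmbeddings(
#         x_self_size=6, entity_sizes=[4], entity_num_max_elements=[5], embedding_size=64
#     )
#     masks = EntityEmbeddings.get_masks([enemies])   # 1.0 where an entity slot is all zeros
#     encoded = embedder(x_self, [enemies])           # (batch_size, 5, 64)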
|
|
class ResidualSelfAttention(torch.nn.Module):
    """
    Residual self-attention block. Applies multi-head self attention to a set of
    encoded entities and average-pools the attended entities into a single vector,
    ignoring padded (masked) entities.
    """

    # Small constant to keep the masked average pooling well-defined
    EPSILON = 1e-7

    def __init__(
        self,
        embedding_size: int,
        entity_num_max_elements: Optional[int] = None,
        num_heads: int = 4,
    ):
        super().__init__()
        self.max_num_ent: Optional[int] = None
        if entity_num_max_elements is not None:
            self.max_num_ent = entity_num_max_elements
        self.attention = MultiHeadAttention(
            num_heads=num_heads, embedding_size=embedding_size
        )

        def make_projection() -> LinearEncoder:
            # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
            return LinearEncoder(
                embedding_size,
                1,
                embedding_size,
                kernel_init=Initialization.Normal,
                kernel_gain=(0.125 / embedding_size) ** 0.5,
            )

        # Query, key, value and output projections around the attention module
        self.fc_q = make_projection()
        self.fc_k = make_projection()
        self.fc_v = make_projection()
        self.fc_out = make_projection()
        self.embedding_norm = LayerNorm()
|
    def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
        # Gather the padding masks of all entity tensors into a single mask
        mask = torch.cat(key_masks, dim=1)

        inp = self.embedding_norm(inp)
        # Feed to self attention
        query = self.fc_q(inp)  # (b, n_q, emb)
        key = self.fc_k(inp)  # (b, n_k, emb)
        value = self.fc_v(inp)  # (b, n_k, emb)

        # Use the configured maximum number of entities when available (required for export)
        num_ent = self.max_num_ent
        if num_ent is None:
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max \
                    number of elements."
                )
            num_ent = inp.shape[1]

        output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
        # Residual connection back to the input entities
        output = self.fc_out(output) + inp
        # Average pooling over the entities that are not masked out
        numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
        return output
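
# The masked average pooling at the end of ResidualSelfAttention.forward computes, for
# each batch element b, a mean over the non-padded entities only:
#
#     output_b = sum_i (1 - mask_b_i) * attended_b_i / (sum_i (1 - mask_b_i) + EPSILON)
#
# where mask_b_i is 1 for padded entity slots; EPSILON keeps the division well-defined
# when every slot of a batch element is masked out.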
|
|
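
# End-to-end sketch continuing the illustrative example above (sizes are assumptions):
# self-attend over the encoded entities and pool them into one fixed-size vector.
#
#     attention = ResidualSelfAttention(embedding_size=64, entity_num_max_elements=5)
#     pooled = attention(encoded, masks)              # (batch_size, 64)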