from typing import List, Tuple

import torch


class MultiHeadAttention(torch.nn.Module):
|
|
|
""" |
|
|
|
Multi Head Attention module. We do not use the regular Torch implementation since |
|
|
|
Barracuda does not support some operators it uses. |
|
|
|
Takes as input to the forward method 3 tensors: |
|
|
|
- query: of dimensions (batch_size, number_of_queries, embedding_size) |
|
|
|
- key: of dimensions (batch_size, number_of_keys, embedding_size) |
|
|
|
- value: of dimensions (batch_size, number_of_keys, embedding_size) |
|
|
|
The forward method will return 2 tensors: |
|
|
|
- The output: (batch_size, number_of_queries, embedding_size) |
|
|
|
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys) |
|
|
|
""" |
|
|
|
""" |
|
|
|
Multi Head Attention module. We do not use the regular Torch implementation since |
|
|
|
Barracuda does not support some operators it uses. |
|
|
|
Takes as input to the forward method 3 tensors: |
|
|
|
- query: of dimensions (batch_size, number_of_queries, embedding_size) |
|
|
|
- key: of dimensions (batch_size, number_of_keys, embedding_size) |
|
|
|
- value: of dimensions (batch_size, number_of_keys, embedding_size) |
|
|
|
The forward method will return 2 tensors: |
|
|
|
- The output: (batch_size, number_of_queries, embedding_size) |
|
|
|
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys) |
|
|
|
:param embedding_size: The size of the embeddings that will be generated (should be |
|
|
|
dividable by the num_heads) |
|
|
|
:param total_max_elements: The maximum total number of entities that can be passed to |
|
|
|
the module |
|
|
|
:param num_heads: The number of heads of the attention module |
|
|
|
""" |
|
|
|
        super().__init__()
        self.n_heads = num_heads
        self.head_size: int = embedding_size // self.n_heads
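
    # The forward pass below is an illustrative sketch only: a plain scaled dot-product
    # multi-head attention that matches the shapes documented in the constructor
    # docstring. The original Barracuda-compatible implementation is not part of this
    # excerpt and may differ (e.g. learned projections or masking).
    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        b, n_q, _ = query.shape
        n_k = key.shape[1]
        # Split the embedding dimension into heads: (batch, num_heads, n, head_size)
        qh = query.reshape(b, n_q, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        kh = key.reshape(b, n_k, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        vh = value.reshape(b, n_k, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        # Attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
        scores = torch.matmul(qh, kh.permute(0, 1, 3, 2)) / (self.head_size ** 0.5)
        attention = torch.softmax(scores, dim=-1)
        # Output: (batch_size, number_of_queries, embedding_size)
        output = torch.matmul(attention, vh)
        output = output.permute(0, 2, 1, 3).reshape(b, n_q, self.n_heads * self.head_size)
        return output, attention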
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EntityEmbeddings(torch.nn.Module): |
|
|
|
""" |
|
|
|
""" |
|
|
|
|
|
|
|
    def __init__(
        self,
        x_self_size: int,
        entity_sizes: List[int],
        entity_num_max_elements: List[int],
        embedding_size: int,
        concat_self: bool = True,
    ):
|
|
|
""" |
|
|
|
A module that generates embeddings for a variable number of entities given "self" |
|
|
|
encoding as well as a representation of each entity of each type. The expected |
|
|
|
input of the forward method will be a list of tensors: each element of the |
|
|
|
list corresponds to a type of entity, the dimension of each tensor is : |
|
|
|
[batch_size, max_num_entities, entity_size] |
|
|
|
:param x_self_size: The size of the self embedding that will be concatenated |
|
|
|
with the entities |
|
|
|
:param entity_sizes: The size of each entity type |
|
|
|
:param entity_num_max_elements: A list of maximum number of entities, must be |
|
|
|
the same length as the number of entity tensors that will be passed to the |
|
|
|
forward method. |
|
|
|
:param embedding_size: The size of the output embeddings |
|
|
|
:param concat_self: If true, the x_self will be concatenated with the entities |
|
|
|
before embedding |
|
|
|
""" |
|
|
|
        super().__init__()
        self.self_size: int = x_self_size
        self.entity_sizes: List[int] = entity_sizes
        self.entity_num_max_elements: List[int] = entity_num_max_elements
        self.concat_self: bool = concat_self
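
        # This excerpt does not show how the per-entity-type encoders are built. The
        # construction below is a minimal, assumed sketch: one plain Linear encoder per
        # entity type, mapping the (optionally self-concatenated) entity vector to
        # embedding_size. The original implementation may use a different encoder.
        self.ent_encoders = torch.nn.ModuleList(
            [
                torch.nn.Linear(
                    (self.self_size if concat_self else 0) + ent_size, embedding_size
                )
                for ent_size in self.entity_sizes
            ]
        )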
|
|
|
|
|
|
|
|
|
|
    def forward(
        self, x_self: torch.Tensor, entities: List[torch.Tensor]
    ) -> torch.Tensor:
        if self.concat_self:
            # Concatenate all observations with self
            self_and_ent: List[torch.Tensor] = []
            for num_entities, ent in zip(self.entity_num_max_elements, entities):
                # Expand x_self so it can be concatenated to every entity of this type
                expanded_self = x_self.reshape(-1, 1, self.self_size)
                expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
                self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
        else:
            self_and_ent = entities
|
|
|
        # Encode and concatenate entities
        encoded_entities = torch.cat(
            [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
            dim=1,
        )
        return encoded_entities
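
    # Illustrative usage with hypothetical sizes (not from the original source):
    #   embed = EntityEmbeddings(
    #       x_self_size=16,
    #       entity_sizes=[8, 12],
    #       entity_num_max_elements=[3, 5],
    #       embedding_size=32,
    #   )
    #   x_self = torch.zeros(4, 16)                                # (batch, x_self_size)
    #   entities = [torch.zeros(4, 3, 8), torch.zeros(4, 5, 12)]   # per entity type
    #   encoded = embed(x_self, entities)                          # (4, 3 + 5, 32)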
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualSelfAttention(torch.nn.Module): |
|
|
|
""" |
|
|
|
A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses |
|
|
|
multi head self attention to encode information about a "Self" and a list of |
|
|
|
relevant "Entities". |
|
|
|
""" |
|
|
|
|
|
|
|
    def __init__(
        self, embedding_size: int, total_max_elements: int, num_heads: int = 4
    ):
|
|
|
""" |
|
|
|
A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses |
|
|
|
multi head self attention to encode information about a "Self" and a list of |
|
|
|
relevant "Entities". |
|
|
|
:param embedding_size: The size of the embeddings that will be generated (should be |
|
|
|
dividable by the num_heads) |
|
|
|
:param total_max_elements: The maximum total number of entities that can be passed to |
|
|
|
the module |
|
|
|
:param num_heads: The number of heads of the attention module |
|
|
|
""" |
|
|
|
        super().__init__()
        self.max_num_ent = total_max_elements
        self.attention = MultiHeadAttention(
            num_heads=num_heads, embedding_size=embedding_size
        )
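
    # Illustrative usage with hypothetical sizes (not from the original source); the
    # rest of the module (e.g. its forward method) is not included in this excerpt:
    #   rsa = ResidualSelfAttention(embedding_size=32, total_max_elements=8)
    #   # rsa.attention applies multi head self attention over at most 8 encoded entities,
    #   # e.g. the output of EntityEmbeddings with sum(entity_num_max_elements) == 8.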
|
|
|