Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
"""

query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
n_q: int,
n_k: int,

# This is to avoid using .size() when possible as Barracuda does not support
# the type Dim
b = -1  # the batch size
query = query.reshape(
    b, n_q, self.n_heads, self.head_size
)  # (b, n_q, h, emb / h)
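
# From here, standard multi-head scaled dot-product attention would proceed
# roughly as sketched below (illustrative only, not necessarily the exact
# implementation; `key` and `value` are assumed to be reshaped into heads the
# same way as `query` above):
#
#     q = query.permute([0, 2, 1, 3])   # (b, h, n_q, head_size)
#     k = key.permute([0, 2, 3, 1])     # (b, h, head_size, n_k)
#     v = value.permute([0, 2, 1, 3])   # (b, h, n_k, head_size)
#     att = torch.softmax(torch.matmul(q, k) / self.head_size ** 0.5, dim=-1)
#     out = torch.matmul(att, v)        # (b, h, n_q, head_size)
#     out = out.permute([0, 2, 1, 3]).reshape(b, n_q, self.embedding_size)
#     # the module then returns (out, att), matching the docstring above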
relevant "Entities". |
|
|
|
""" |

EPSILON = 1e-7

def __init__(
    self,

query = self.fc_q(inp)  # (b, n_q, emb)
key = self.fc_k(inp)  # (b, n_k, emb)
value = self.fc_v(inp)  # (b, n_k, emb)
# Self-attention: queries, keys and values all come from the same entity input
output, _ = self.attention(
    query, key, value, self.max_num_ent, self.max_num_ent, mask
)
# Residual
output = self.fc_out(output) + inp

# Average pooling over the entities that are not masked out
numerator = torch.sum(
    output * (1 - mask).reshape(-1, self.max_num_ent, 1), dim=1
)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
output = numerator / denominator

# Residual between x_self and the output of the module
return output
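

# A standalone sketch of the masked average pooling performed just before the
# return above. The function name and the toy sizes are illustrative assumptions;
# it relies on `torch`, which the rest of this file already imports.
def _masked_mean_pooling_demo() -> torch.Tensor:
    EPSILON = 1e-7  # mirrors the class constant used above
    output = torch.rand(2, 5, 8)  # (batch, entities, embedding)
    # mask == 1 marks a padded / absent entity, 0 a real one
    mask = torch.tensor(
        [[0.0, 0.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0]]
    )
    # Sum only the real entities, then divide by how many there are per batch row
    numerator = torch.sum(output * (1 - mask).reshape(-1, 5, 1), dim=1)
    denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON
    return numerator / denominator  # (2, 8): mean embedding over real entities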