        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_mask: Optional[torch.Tensor] = None,
        number_of_keys: int = -1,
        number_of_queries: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
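        """
        Scaled dot-product attention: queries and keys are compared to produce attention
        weights (scaled by the square root of the embedding size), which are then used to
        combine the values. key_mask is an optional (b, n_k) tensor containing 1 for keys
        that must be ignored (e.g. zero-padded observations) and 0 otherwise; when it is
        None, no masking is applied. number_of_keys and number_of_queries override the
        sizes otherwise inferred from the key and query tensors when different from -1.
        Returns the attention output of shape (b, n_q, emb) and the attention weights of
        shape (b, h, n_q, n_k).
        """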
        b = -1  # the batch size is dynamic, so let reshape infer it
        n_q = number_of_queries if number_of_queries != -1 else query.size(1)
        n_k = number_of_keys if number_of_keys != -1 else key.size(1)
        query = self.fc_q(query)  # (b, n_q, h*d)
        key = self.fc_k(key)  # (b, n_k, h*d)
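        # NOTE: the value projection and the split into attention heads are not part of
        # this fragment. The sketch below is an assumption: it mirrors fc_q / fc_k and
        # the (b, h, n_q, n_k) shape comments; `self.fc_v` and `self.n_heads` are assumed
        # attribute names not taken from this fragment.
        value = self.fc_v(value)  # (b, n_k, h*d)

        head_size = self.embedding_size // self.n_heads
        query = query.reshape(b, n_q, self.n_heads, head_size)
        query = query.permute(0, 2, 1, 3)  # (b, h, n_q, d)
        key = key.reshape(b, n_k, self.n_heads, head_size)
        key = key.permute(0, 2, 3, 1)  # (b, h, d, n_k)
        value = value.reshape(b, n_k, self.n_heads, head_size)
        value = value.permute(0, 2, 1, 3)  # (b, h, n_k, d)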
        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)
        if key_mask is None:
            qk = qk / (self.embedding_size ** 0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = (1 - key_mask) * qk / (
                self.embedding_size ** 0.5
            ) + key_mask * self.NEG_INF

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)
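
        # NOTE: combining the attention weights with the values and merging the heads
        # back together is also missing from this fragment; this is a sketch based on
        # the shape comments above (value assumed to be (b, h, n_k, d) at this point).
        value_attention = torch.matmul(att, value)  # (b, h, n_q, d)
        value_attention = value_attention.permute(0, 2, 1, 3)  # (b, n_q, h, d)
        value_attention = value_attention.reshape(b, n_q, self.embedding_size)  # (b, n_q, emb)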

        out = self.fc_out(value_attention)  # (b, n_q, emb)
        return out, att


class ZeroObservationMask(torch.nn.Module):
    """
    Takes a List of Tensors and returns a List of mask Tensors with 1 if the input was
    all zeros and 0 otherwise. This is used in the Attention layer to mask the padding
    observations.
    """

    def __init__(self):
        super().__init__()

    def forward(self, observations: List[torch.Tensor]) -> List[torch.Tensor]:
        with torch.no_grad():
            # Generate the masking tensors for each entity tensor (mask only if all zeros)
            key_masks: List[torch.Tensor] = [
                (torch.sum(ent ** 2, dim=2) < 0.01).type(torch.FloatTensor)
                for ent in observations
            ]
        return key_masks


class SimpleTransformer(torch.nn.Module):
    """
    A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses
    multi-head self-attention to encode a "self" observation together with a variable
    number of "entity" observations.
    """

    EPSILON = 1e-7

    def __init__(
        self,
        x_self_size: int,
        entities_sizes: List[int],
        embedding_size: int,
        output_size: Optional[int] = None,
    ):
        super().__init__()
        self.self_size = x_self_size
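        # NOTE: the rest of __init__ between this line and the `embedding_size=` keyword
        # below is not part of this fragment. The sketch that follows is an assumption:
        # it keeps the entity sizes for forward(), builds one LinearEncoder per entity
        # type (following the 3-argument LinearEncoder pattern used elsewhere in this
        # class, with x_self assumed to be concatenated to each entity before encoding),
        # and opens the MultiHeadAttention constructor call that the next two lines
        # close. The `entities_sizes`, `ent_encoders` and `attention` attribute names
        # are assumptions as well.
        self.entities_sizes = entities_sizes
        self.entities_num_max_elements: Optional[List[int]] = None
        self.ent_encoders = torch.nn.ModuleList(
            [
                LinearEncoder(self.self_size + ent_size, 1, embedding_size)
                for ent_size in self.entities_sizes
            ]
        )
        self.attention = MultiHeadAttention(
            # other constructor arguments (if any) are not visible in this fragment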
            embedding_size=embedding_size,
        )
        self.residual_layer = LinearEncoder(embedding_size, 1, embedding_size)
        if output_size is None:
            output_size = embedding_size
        self.x_self_residual_layer = LinearEncoder(
            embedding_size + x_self_size, 1, output_size
        )

    def forward(
        self,
        x_self: torch.Tensor,
        entities: List[torch.Tensor],
        key_masks: List[torch.Tensor],
    ) -> torch.Tensor:
        # Gather the maximum number of entities for each entity tensor (computed once)
        if self.entities_num_max_elements is None:
            self.entities_num_max_elements = []
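        # NOTE: the middle of forward() is missing from this fragment: the part that
        # fills `entities_num_max_elements`, encodes the entities together with x_self,
        # runs them through self.attention and self.residual_layer, and defines the
        # `output`, `mask` and `max_num_ent` names used by the masked average below.
        # The sketch that follows is an assumption kept consistent with those later lines.
            for ent in entities:
                self.entities_num_max_elements.append(ent.shape[1])

        # Concatenate x_self to each entity before encoding it
        self_and_ent: List[torch.Tensor] = []
        for num_entities, ent in zip(self.entities_num_max_elements, entities):
            expanded_self = x_self.reshape(-1, 1, self.self_size).repeat(1, num_entities, 1)
            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))

        # Encode each entity type, then concatenate the encodings and the masks
        qkv = torch.cat(
            [encoder(x) for encoder, x in zip(self.ent_encoders, self_and_ent)],
            dim=1,
        )
        mask = torch.cat(key_masks, dim=1)
        max_num_ent = sum(self.entities_num_max_elements)

        # Self-attention over the encoded entities, followed by a residual connection
        output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
        output = self.residual_layer(output) + qkv

        # Masked average pooling over the entity dimension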
        numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
        # Residual connection between x_self and the output of the module
        output = self.x_self_residual_layer(torch.cat([output, x_self], dim=1))
        return output

    @staticmethod
    def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Takes a List of Tensors and returns a List of mask Tensors with 1 if the input was
        all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
        layer to mask the padding observations.
        """
        with torch.no_grad():
            # Generate the masking tensors for each entity tensor (mask only if all zeros)
            key_masks: List[torch.Tensor] = [
                (torch.sum(ent ** 2, dim=2) < 0.01).type(torch.FloatTensor)
                for ent in observations
            ]
        return key_masks
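

if __name__ == "__main__":
    # Minimal usage sketch for the masking helper (illustrative only; the batch size,
    # entity counts and entity sizes below are made up). Padding entities are all-zero
    # rows, which the masks mark with 1 so the attention layer can ignore them.
    entities = [torch.zeros(4, 3, 5), torch.ones(4, 2, 6)]  # two entity types, batch of 4
    masks = SimpleTransformer.get_masks(entities)
    print([m.shape for m in masks])  # [torch.Size([4, 3]), torch.Size([4, 2])]
    print(masks[0])  # all ones: every entity of the first type is zero padding here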