# Encode each variable-length input group with its matching encoder.
encoded_var_len = [
    encoder(x) for encoder, x in zip(self.var_len_encoders, var_len_inputs)
]

# Build the entity embeddings (queries/keys/values) from the fixed-size
# self features and the variable-length encodings.
qkv = self.entities_embeddings(encoded_self, encoded_var_len)

# Normalize each entity's features to zero mean and unit variance
# (a hand-rolled layer norm over dim=2; the small constant guards
# against division by zero).
mu_qkv = torch.mean(qkv, dim=2, keepdim=True)
qkv = (qkv - mu_qkv) / (
    torch.sqrt(torch.mean((qkv - mu_qkv) ** 2, dim=2, keepdim=True))
    + 0.0001
)
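
# For reference: the step above is essentially an unparameterized layer
# norm. A minimal sketch, assuming qkv has shape (batch, entities,
# features); note that F.layer_norm adds eps inside the square root, so
# its output differs very slightly from the manual version:
#
#     import torch.nn.functional as F
#     qkv = F.layer_norm(qkv, normalized_shape=qkv.shape[-1:], eps=1e-4)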

# Masked self-attention over the entity axis; masks flag which
# entity slots carry real (non-padded) data.
attention_embedding = self.rsa(qkv, masks)

# Append the attention embedding to the self features along dim=1.
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
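
# Shape sketch (assumptions, not asserted by this code): with
# encoded_self of shape (B, D) and qkv of shape (B, N, D'), self.rsa is
# expected to reduce over the N entities so that attention_embedding is
# (B, D''), which makes the concatenation along dim=1 valid.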