 if len(var_len_indices) > 0:
     entities_sizes = [sensor_specs[idx].shape[1] for idx in var_len_indices]
     entities_max_len = [sensor_specs[idx].shape[0] for idx in var_len_indices]
+    self.x_self_encoder = LinearEncoder(x_self_len, 2, self.h_size // 2)
+    self.var_len_encoders = torch.nn.ModuleList(
+        [
+            LinearEncoder(ent_size, 2, self.h_size // 2)
+            for ent_size in entities_sizes
+        ]
+    )
     self.entities_embeddings = EntityEmbeddings(
-        x_self_len, entities_sizes, entities_max_len, self.h_size
+        self.h_size // 2,
+        [self.h_size // 2] * len(var_len_indices),
+        entities_max_len,
+        self.h_size,
     )
-    total_enc_size = x_self_len + self.h_size
+    total_enc_size = self.h_size // 2 + self.h_size
     n_layers = max(1, network_settings.num_layers - 2)
 else:
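
With the hunk above (apparently from the module's constructor), the width of the final concatenated encoding no longer depends on the raw self-observation size: the self observation is compressed to `self.h_size // 2` by `x_self_encoder`, each variable-length observation gets its own encoder to the same width, and the attention output stays `self.h_size` wide, hence `total_enc_size = self.h_size // 2 + self.h_size`. Below is a minimal sketch of that size bookkeeping; the `linear_encoder` helper is a hypothetical stand-in that only assumes `LinearEncoder(input_size, num_layers, hidden_size)` maps `input_size` to `hidden_size`, matching the call pattern above.

    import torch

    # Hypothetical stand-in for LinearEncoder: input_size -> hidden_size
    # through num_layers Linear+ReLU blocks. (An assumption for this sketch;
    # the real class lives elsewhere in the codebase.)
    def linear_encoder(input_size, num_layers, hidden_size):
        layers = [torch.nn.Linear(input_size, hidden_size), torch.nn.ReLU()]
        for _ in range(num_layers - 1):
            layers += [torch.nn.Linear(hidden_size, hidden_size), torch.nn.ReLU()]
        return torch.nn.Sequential(*layers)

    h_size, x_self_len = 128, 37    # example sizes, not taken from the diff
    entities_sizes = [6, 11]        # feature width of each variable-length sensor

    x_self_encoder = linear_encoder(x_self_len, 2, h_size // 2)
    var_len_encoders = torch.nn.ModuleList(
        [linear_encoder(ent_size, 2, h_size // 2) for ent_size in entities_sizes]
    )

    encoded_self = x_self_encoder(torch.ones(1, x_self_len))
    assert encoded_self.shape == (1, h_size // 2)   # self is now 64 wide, not 37
    total_enc_size = h_size // 2 + h_size           # 192, independent of x_self_len

Before the change, `total_enc_size = x_self_len + self.h_size` tied the downstream layers' input width to whatever self observation the environment happened to emit.
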
 if len(var_len_inputs) > 0:
     # Some inputs need to be processed with a variable length encoder
     masks = EntityEmbeddings.get_masks(var_len_inputs)
-    qkv = self.entities_embeddings(encoded_self, var_len_inputs)
+    encoded_self = self.x_self_encoder(encoded_self)
+    encoded_var_len = [
+        encoder(x) for encoder, x in zip(self.var_len_encoders, var_len_inputs)
+    ]
+    qkv = self.entities_embeddings(encoded_self, encoded_var_len)
+    mu_qkv = torch.mean(qkv, dim=2, keepdim=True)
+    qkv = (qkv - mu_qkv) / (
+        torch.sqrt(torch.mean((qkv - mu_qkv) ** 2, dim=2, keepdim=True))
+        + EPSILON
+    )
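
The lines appended after the embedding call normalize the stacked query/key/value tensor by hand: subtract the per-entity mean over the embedding dimension (`dim=2`), then divide by the standard deviation plus a small `EPSILON` guard. That is layer normalization without learnable scale or bias. The sketch below checks, on made-up tensors, that the manual form agrees with `torch.nn.functional.layer_norm`; the `EPSILON` value and the `(batch, num_entities, embedding)` layout are assumptions here.

    import torch
    import torch.nn.functional as F

    EPSILON = 1e-7              # assumed value; the constant is defined elsewhere
    qkv = torch.randn(4, 3, 32) # (batch, num_entities, embedding), made up

    mu_qkv = torch.mean(qkv, dim=2, keepdim=True)
    manual = (qkv - mu_qkv) / (
        torch.sqrt(torch.mean((qkv - mu_qkv) ** 2, dim=2, keepdim=True)) + EPSILON
    )

    # F.layer_norm computes (x - mean) / sqrt(var + eps); the hand-rolled version
    # adds EPSILON outside the sqrt, so the two agree only up to epsilon placement.
    reference = F.layer_norm(qkv, normalized_shape=(32,), eps=1e-12)
    assert torch.allclose(manual, reference, rtol=1e-4, atol=1e-6)

Skipping the learnable affine parameters keeps the q, k, v magnitudes at a consistent scale without adding weights, which is presumably why the normalization is inlined rather than done with `torch.nn.LayerNorm`.
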
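One ordering detail worth noting: `EntityEmbeddings.get_masks(var_len_inputs)` still consumes the raw variable-length observations and runs before the new per-entity encoders, so padding has to be detectable in the raw data. The sketch below illustrates the all-zero-row convention this suggests; it is an assumption about the padding convention, not the library's implementation.

    import torch
    from typing import List

    def get_masks_sketch(var_len_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        # One mask per variable-length input, shaped (batch, max_entities):
        # 1.0 where an entity row is all zeros (an empty, padded slot), else 0.0.
        return [(obs.pow(2).sum(dim=2) < 1e-8).float() for obs in var_len_inputs]

    obs = torch.zeros(1, 3, 4)          # batch=1, up to 3 entities, 4 features each
    obs[0, 0] = torch.ones(4)           # only the first entity slot is occupied
    print(get_masks_sketch([obs])[0])   # tensor([[0., 1., 1.]])

Masking from the raw inputs also sidesteps a subtle trap: the encoders have bias terms, so an encoded padding row would generally no longer be all zeros.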