    normalize=self.normalize,
)
|
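# total width of the encoded input: observation encoding plus the encoded actions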
total_enc_size = encoder_input_size + encoded_act_size |
|
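# LinearEncoder(input_size, num_layers, hidden_size) is presumably the
# ML-Agents MLP encoder; this one embeds the 6-dim "self" observation into a
# 64-dim vector with a single hidden layer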
self.self_embedding = LinearEncoder(6, 1, 64) |
|
# TODO : This is a Hack
var_len_input = vis_inputs[0].reshape(-1, 20, 4)
# all-zero entity rows are padding; the .float() cast is an assumption so that
# the 1 - key_mask arithmetic below works on the boolean comparison result
key_mask = (
    torch.sum(var_len_input ** 2, axis=2) < 0.01
).float()  # 1 means mask and 0 means let through
|
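# (assumption) x_self is the (b, 64) self embedding; broadcast it across the
# 20 entity slots so it can be concatenated with each entity's features
expanded_x_self = x_self.reshape(-1, 1, 64).repeat(1, 20, 1)  # (b,20,64)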
objects = torch.cat([expanded_x_self, var_len_input], dim=2) # (b,20,68) |
obj_encoding = self.obs_embeding(objects) # (b,20,64) |
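# obs_embeding applies the same weights to each of the 20 (self, entity)
# pairs, so the encoding is shared across entity slots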
self_and_key_emb = torch.cat( |
    [x_self.reshape(-1, 1, 64), obj_encoding], dim=1
)  # (b,21,64)
# the self token at position 0 must always be attended to, so a zero entry
# is prepended to the mask (this construction is an assumption)
key_mask = torch.cat(
    [torch.zeros_like(key_mask[:, :1]), key_mask], dim=1
)  # first one is never masked
|
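# self-attention over the 21 tokens: queries, keys, and values are all the
# same embedding, and key_mask hides the padded entity slots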
output, _ = self.attention( |
    self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
)
# masked average pooling: zero out masked tokens, sum over the sequence,
# and divide by the number of unmasked tokens (the numerator term is assumed)
output = torch.sum(
    output * (1 - key_mask).reshape(-1, 21, 1), dim=1
) / torch.sum(
    1 - key_mask, dim=1, keepdim=True
)  # average pooling
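# output is now (b, 64): one pooled encoding per batch element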
|
if self.use_lstm: |
    # Resize to (batch, sequence length, encoding size)