             else 0
         )
-        self.visual_processors, self.vector_processors, self.attention, encoder_input_size = ModelUtils.create_input_processors(
+        self.visual_processors, self.vector_processors, self.attention, _ = ModelUtils.create_input_processors(
             observation_shapes,
             self.h_size,
             network_settings.vis_encode_type,

-        total_enc_size = encoder_input_size + encoded_act_size
+        # total_enc_size = encoder_input_size + encoded_act_size

-        self.self_embedding = LinearEncoder(6, 1, 64)
-        self.obs_embeding = LinearEncoder(4 + 64, 1, 64)
+        self.self_embedding = LinearEncoder(6, 2, 64)
+        self.obs_embeding = LinearEncoder(4, 2, 64)

         self.linear_encoder = LinearEncoder(
-            total_enc_size, network_settings.num_layers, self.h_size
+            64 * 2, network_settings.num_layers, self.h_size
         )
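Note on the sizing above: the embeddings are hard-coded for this experiment (a 6-dim "self" observation and 4-dim entity observations, each embedded to 64 units by 2-layer encoders), and the trunk now takes `64 * 2` inputs because the pooled attention output (64) is concatenated with the self embedding (64) in `forward()`. A minimal shape check, using a stand-in for `LinearEncoder(input_size, num_layers, hidden_size)` rather than the real `mlagents` implementation:

```python
import torch
import torch.nn as nn

# Stand-in for ML-Agents' LinearEncoder(input_size, num_layers, hidden_size);
# the real one lives in mlagents.trainers.torch.layers and uses a Swish-style
# activation. This sketch only mirrors the shapes.
class LinearEncoder(nn.Module):
    def __init__(self, input_size: int, num_layers: int, hidden_size: int):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers += [nn.Linear(input_size, hidden_size), nn.SiLU()]
            input_size = hidden_size
        self.seq = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq(x)

self_embedding = LinearEncoder(6, 2, 64)
obs_embeding = LinearEncoder(4, 2, 64)
x_self = self_embedding(torch.randn(8, 6))     # (8, 64)
obj_emb = obs_embeding(torch.randn(8, 20, 4))  # (8, 20, 64)
```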
         if self.use_lstm:

         # TODO : This is a Hack
         var_len_input = vis_inputs[0].reshape(-1, 20, 4)
         key_mask = (
-        )  # 1 means mask and 0 means let through
+        ).type(torch.FloatTensor)  # 1 means mask and 0 means let through
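The body of the `key_mask` expression is elided in this excerpt. A plausible construction, assuming zero-padded entity rows mark absent entities (hypothetical, not necessarily the branch's actual code), keeping the diff's convention that 1 means mask and 0 means let through:

```python
import torch

# Hypothetical mask construction: treat all-zero entity rows as padding.
# Convention from the diff: 1.0 = masked out, 0.0 = attended to.
var_len_input = torch.zeros(8, 20, 4)
var_len_input[:, :5] = torch.randn(8, 5, 4)  # only 5 real entities
key_mask = (var_len_input.abs().sum(dim=2) < 1e-6).type(torch.FloatTensor)
print(key_mask.shape)  # (8, 20)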
-        objects = torch.cat([expanded_x_self, var_len_input], dim=2)  # (b,20,68)
-        obj_encoding = self.obs_embeding(objects)  # (b,20,64)
+        obj_emb = self.obs_embeding(var_len_input)  # (b,20,64)
+        objects = torch.cat([expanded_x_self, obj_emb], dim=2)  # (b,20,128)
+        obj_and_self = self.self_and_obs_embedding(objects)  # (b,20,64)
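`expanded_x_self` is not defined in this excerpt; presumably it is the 64-dim self embedding repeated across the entity axis so it can be concatenated entity-wise (the `(b,20,68)` comment in the removed lines matches 64 + 4). A sketch of that assumption:

```python
import torch

x_self = torch.randn(8, 64)                               # self embedding
obj_emb = torch.randn(8, 20, 64)                          # entity embeddings
expanded_x_self = x_self.unsqueeze(1).expand(-1, 20, -1)  # (8, 20, 64)
objects = torch.cat([expanded_x_self, obj_emb], dim=2)    # (8, 20, 128)
```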
-        self_and_key_emb = torch.cat(
-            [x_self.reshape(-1, 1, 64), obj_encoding], dim=1
-        )  # (b,21,64)
-        key_mask = torch.cat(
-            [torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1
-        )  # first one is never masked
+        # self_and_key_emb = torch.cat(
+        #     [x_self.reshape(-1, 1, 64), obj_and_self], dim=1
+        # )  # (b,21,64)
+        # key_mask = torch.cat(
+        #     [torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1
+        # )  # first one is never masked
+        # output, _ = self.attention(
+        #     self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
+        # )  # (b, 21, 64)
         output, _ = self.attention(
-            self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
+            obj_and_self, obj_and_self, obj_and_self, key_mask
         )
-        output = self.dense_after_attention(output) + self_and_key_emb
+        output = self.dense_after_attention(output) + obj_and_self
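`self.attention` comes from `create_input_processors` and its implementation is not shown in this diff. Its `(query, key, value, key_mask)` call with 1 = masked matches PyTorch's `key_padding_mask` convention, so a rough stand-in for this attention + residual step might look like:

```python
import torch
import torch.nn as nn

# Rough stand-in for the attention + residual step; the branch's actual
# attention module is not shown here, so nn.MultiheadAttention is used
# purely for illustration.
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
dense_after_attention = nn.Linear(64, 64)

obj_and_self = torch.randn(8, 20, 64)
key_mask = torch.zeros(8, 20)
key_mask[:, 5:] = 1.0  # pad entities 5..19

output, _ = mha(
    obj_and_self, obj_and_self, obj_and_self,
    key_padding_mask=key_mask.bool(),  # True = ignore this key
)
output = dense_after_attention(output) + obj_and_self  # residual, (8, 20, 64)
```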
         output = torch.sum(
-            output * (1 - key_mask).reshape(-1, 21, 1), dim=1
-        ) / torch.sum(
+            output * (1 - key_mask).reshape(-1, 20, 1), dim=1
+        ) / (torch.sum(
             1 - key_mask, dim=1, keepdim=True
-        )  # average pooling
+        ) + 0.001 )  # average pooling
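This pools by dividing the masked sum by the number of unmasked entities rather than by the fixed length 20 (the middle line of the denominator is elided in the excerpt; `torch.sum(1 - key_mask, dim=1, keepdim=True)` is the reading consistent with the "average pooling" comment), with `+ 0.001` guarding against division by zero when every entity is masked. A self-contained shape check:

```python
import torch

output = torch.randn(8, 20, 64)
key_mask = torch.zeros(8, 20)
key_mask[:, 5:] = 1.0  # 5 real entities per batch element

# Masked average pooling: sum over real entities / count of real entities.
pooled = torch.sum(
    output * (1 - key_mask).reshape(-1, 20, 1), dim=1
) / (torch.sum(1 - key_mask, dim=1, keepdim=True) + 0.001)
print(pooled.shape)  # (8, 64)
```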
-        encoding = self.linear_encoder(output + x_self)
+        encoding = self.linear_encoder(torch.cat([output, x_self], dim=1))

         if self.use_lstm:
             # Resize to (batch, sequence length, encoding size)