            else 0
        )

        self.visual_processors, self.vector_processors, self.attention, encoder_input_size = ModelUtils.create_input_processors(
            observation_shapes,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )

        self.linear_encoder = LinearEncoder(
            total_enc_size, network_settings.num_layers, self.h_size
        )
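
        # NOTE: on this experimental branch, create_input_processors also
        # returns an attention module. The attention path in the forward pass
        # below feeds linear_encoder the 64-dim pooled attention output rather
        # than the concatenated encoder outputs, so total_enc_size presumably
        # works out to 64 here.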

        self.self_embedding = LinearEncoder(6, 1, 64)
        self.obs_embedding = LinearEncoder(4, 1, 64)
        self.self_and_obs_embedding = LinearEncoder(64 + 64, 1, 64)
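
        # These sizes look hard-coded for a specific test scene: a 6-dim "self"
        # observation and up to 20 entities with 4 features each, all embedded
        # to 64 dims (self and entity embeddings are concatenated, hence the
        # 64 + 64 input to the joint embedding).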

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)

        if actions is not None:
            inputs = torch.cat(encodes + [actions], dim=-1)
        else:
            inputs = torch.cat(encodes, dim=-1)
        encoding = self.linear_encoder(inputs)

        # TODO: This is a hack
        var_len_input = vis_inputs[0].reshape(-1, 20, 4)
        # 1 means mask (padding) and 0 means let through; all-zero entity rows
        # are treated as padding. Cast to float so the mask can be concatenated
        # with and subtracted from float tensors below.
        key_mask = (torch.sum(var_len_input ** 2, dim=2) < 0.01).float()
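
        # The variable-length entity observations are smuggled in through the
        # first visual observation as a flat buffer and reshaped to
        # (batch, 20, 4), which is what the TODO above refers to.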

        self_encoding = processed_vec.reshape(-1, 1, processed_vec.shape[1])
        self_encoding = self.self_embedding(self_encoding)  # (b, 1, 64)

        obs_encoding = self.obs_embedding(var_len_input)  # (b, 20, 64)

        expanded_self_encoding = self_encoding.reshape(-1, 1, 64).repeat(1, 20, 1)
        self_and_key_emb = self.self_and_obs_embedding(
            torch.cat([obs_encoding, expanded_self_encoding], dim=2)
        )
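
        # Each entity is embedded jointly with a broadcast copy of the self
        # embedding, so every attention key/value carries the agent's own state
        # alongside that entity's features.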

        # Add the self embedding to the entities before attention.
        self_and_key_emb = torch.cat([self_encoding, self_and_key_emb], dim=1)  # (b, 21, 64)
        # The self entity is real, never padding, so it is never masked.
        key_mask = torch.cat([torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1)

        output, _ = self.attention(self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask)
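
        # self.attention appears to take (query, key, value, key_mask) and
        # return (outputs, attention weights); self-attention uses the same
        # (b, 21, 64) tensor for all three, with masked entities ignored.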

        # Masked mean over the 21 entity outputs: zero out the masked slots,
        # then divide by the number of unmasked entities.
        output = torch.sum(output * (1 - key_mask).reshape(-1, 21, 1), dim=1) / torch.sum(
            1 - key_mask, dim=1, keepdim=True
        )
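
        # i.e. pooled_b = sum_i (1 - m_b,i) * out_b,i / sum_i (1 - m_b,i), a
        # permutation-invariant summary over a variable number of entities.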

        # output = torch.cat([inputs, output], dim=1)

        encoding = self.linear_encoder(output)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])

        actor_mem, critic_mem = None, None
        if self.use_lstm:
            # Use only the back half of memories for the critic.
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
        if len(vis_inputs) == 0:
            print("0 vis obs for some reason")
        value_outputs, critic_mem_out = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
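
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the padding mask and
# masked-mean pooling used above, run end to end on dummy tensors. Assumptions:
# torch.nn.MultiheadAttention stands in for this branch's self.attention
# module, a plain Linear stands in for LinearEncoder, and the shapes
# (20 entities x 4 features, 64-dim embeddings) mirror the hard-coded values
# above.
# ---------------------------------------------------------------------------
import torch


def masked_attention_pool_demo() -> torch.Tensor:
    batch, n_entities, feat, emb = 2, 20, 4, 64
    entities = torch.randn(batch, n_entities, feat)
    entities[:, 10:, :] = 0.0  # pretend the last 10 entity slots are padding

    # 1 means mask (padding) and 0 means let through, as above.
    key_mask = (torch.sum(entities ** 2, dim=2) < 0.01).float()  # (b, 20)

    embed = torch.nn.Linear(feat, emb)  # stand-in for LinearEncoder
    x = embed(entities)  # (b, 20, 64)

    attention = torch.nn.MultiheadAttention(emb, num_heads=4, batch_first=True)
    out, _ = attention(x, x, x, key_padding_mask=key_mask.bool())

    # Masked mean: zero the padded slots, divide by the count of real entities.
    keep = (1 - key_mask).unsqueeze(-1)  # (b, 20, 1)
    return (out * keep).sum(dim=1) / keep.sum(dim=1)  # (b, 64)


if __name__ == "__main__":
    print(masked_attention_pool_demo().shape)  # torch.Size([2, 64])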