|
|
|
|
|
|
self.use_fc = False |
|
|
|
|
|
|
|
if not self.use_fc: |
|
|
|
emb_size = 32 |
|
|
|
emb_size = 64 |
|
|
|
x_self_size=32, |
|
|
|
entities_sizes=[32], # hard coded, 4 obs per entity |
|
|
|
x_self_size=64, |
|
|
|
entities_sizes=[64], # hard coded, 4 obs per entity |
|
|
|
self.self_embedding = LinearEncoder(6, 2, 32) |
|
|
|
self.obs_embeding = LinearEncoder(4, 2, 32) |
|
|
|
#self.self_embedding = LinearEncoder(6, 2, 64) |
|
|
|
#self.obs_embeding = LinearEncoder(4, 2, 64) |
|
|
|
#6 self (2 coord x 3 stacks) and 4 observable = 10 |
|
|
|
self.embedding_encoder = LinearEncoder(10, 2, 64) |
|
|
|
|
|
|
|
emb_size + 32, network_settings.num_layers - 1, self.h_size |
|
|
|
emb_size, network_settings.num_layers - 1, self.h_size |
|
|
|
) |
|
|
|
else: |
|
|
|
self.linear_encoder = LinearEncoder( |
|
|
|
|
|
|
inputs = torch.cat(encodes, dim=-1) |
|
|
|
|
|
|
|
if not self.use_fc: |
|
|
|
x_self = self.self_embedding(processed_vec) |
|
|
|
#x_self = self.embedding(processed_vec) |
|
|
|
processed_var_len_input = self.obs_embeding(var_len_input) |
|
|
|
output = self.transformer(x_self, [processed_var_len_input], self.masking_module([var_len_input])) |
|
|
|
expanded_self = processed_vec.reshape(-1, 1, 6) |
|
|
|
expanded_self = torch.cat([expanded_self] * 20, dim=1) |
|
|
|
concat_ent = torch.cat([expanded_self, var_len_input], dim=-1) |
|
|
|
processed_var_len_input = self.embedding_encoder(concat_ent) |
|
|
|
output = self.transformer(processed_var_len_input, self.masking_module([var_len_input])) |
|
|
|
#output = self.transformer(x_self, [processed_var_len_input], self.masking_module([var_len_input])) |
|
|
|
|
|
|
|
# # TODO : This is a Hack |
|
|
|
# var_len_input = vis_inputs[0].reshape(-1, 20, 4) |
|
|
|
|
|
|
# 1 - key_mask, dim=1, keepdim=True |
|
|
|
# ) + 0.001 ) # average pooling |
|
|
|
|
|
|
|
encoding = self.linear_encoder(torch.cat([output, x_self], dim=1)) |
|
|
|
#encoding = self.linear_encoder(torch.cat([output, x_self], dim=1)) |
|
|
|
encoding = self.linear_encoder(output) |
|
|
|
else: |
|
|
|
encoding = self.linear_encoder(torch.cat([vis_inputs[0].reshape(-1, 80), processed_vec], dim=1)) |
|
|
|
|
|
|
|