|
|
|
|
|
|
|
|
|
|
emb_size = 64 |
|
|
|
self.transformer = SimpleTransformer( |
|
|
|
x_self_size=encoder_input_size, |
|
|
|
entities_sizes=[4], # hard coded, 4 obs per entity |
|
|
|
x_self_size=64, |
|
|
|
entities_sizes=[64], # hard coded, 4 obs per entity |
|
|
|
# self.self_embedding = LinearEncoder(6, 2, 64) |
|
|
|
# self.obs_embeding = LinearEncoder(4, 2, 64) |
|
|
|
self.self_embedding = LinearEncoder(6, 2, 64) |
|
|
|
self.obs_embeding = LinearEncoder(4, 2, 64) |
|
|
|
emb_size + encoder_input_size, network_settings.num_layers, self.h_size |
|
|
|
64 + 64, network_settings.num_layers, self.h_size |
|
|
|
) |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
|
|
|
inputs = torch.cat(encodes, dim=-1) |
|
|
|
|
|
|
|
x_self = processed_vec.reshape(-1, processed_vec.shape[1]) |
|
|
|
x_self = self.self_embedding(x_self) |
|
|
|
var_len_input = self.obs_embeding(var_len_input) |
|
|
|
output = self.transformer(x_self, [var_len_input]) |
|
|
|
|
|
|
|
# # TODO : This is a Hack |
|
|
|