|
|
|
|
|
|
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import (
    LSTM,
    LinearEncoder,
    SimpleTransformer,
    ZeroObservationMask,
)
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
|
|
|
|
|
|
|
ActivationFunction = Callable[[torch.Tensor], torch.Tensor] |
|
|
|
|
|
|
            normalize=self.normalize,
        )
|
|
|
|
|
|
|
        self.use_fc = False
        if not self.use_fc:
            emb_size = 32

            self.masking_module = ZeroObservationMask()
            self.transformer = SimpleTransformer(
                x_self_size=32,
                entities_sizes=[32],  # hard coded, 4 obs per entity
                embedding_size=emb_size,
            )
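            # Note (assumption): each entity's 4 raw obs are embedded to 32 by
            # self.obs_embeding below; SimpleTransformer presumably attends over the
            # 20 entity embeddings conditioned on the 32-dim self embedding and
            # returns a pooled vector of size emb_size.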
|
|
|
            # total_enc_size = encoder_input_size + encoded_act_size
            self.self_embedding = LinearEncoder(6, 2, 32)
            self.obs_embeding = LinearEncoder(4, 2, 32)
            # self.self_and_obs_embedding = LinearEncoder(64 + 64, 1, 64)
            # self.dense_after_attention = LinearEncoder(64, 1, 64)
            self.linear_encoder = LinearEncoder(
                emb_size + 32, network_settings.num_layers - 1, self.h_size
            )
        else:
            self.linear_encoder = LinearEncoder(
                6 + 4 * 20, network_settings.num_layers + 2, self.h_size
            )
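        # Note (assumption): the fully connected fallback flattens the entity
        # observations directly: 6 self-observation values plus 20 entities * 4 obs
        # each = 86 inputs, matching the vis_inputs[0].reshape(-1, 80) concatenated
        # with processed_vec in forward().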
|
|
|
|
|
|
|
        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
|
|
|
|
|
|
        else:
            inputs = torch.cat(encodes, dim=-1)
|
|
|
|
|
|
|
        if not self.use_fc:
            x_self = self.self_embedding(processed_vec)
            var_len_input = vis_inputs[0].reshape(-1, 20, 4)
            processed_var_len_input = self.obs_embeding(var_len_input)
            output = self.transformer(
                x_self, [processed_var_len_input], self.masking_module([var_len_input])
            )
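            # Note (assumption): ZeroObservationMask presumably flags entities whose
            # raw 4-value observation is all (near) zero so the transformer can skip
            # the padded slots, mirroring the commented-out key_mask hack kept below
            # (torch.sum(var_len_input ** 2, axis=2) < 0.01).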
|
|
|
            # # TODO : This is a Hack
            # var_len_input = vis_inputs[0].reshape(-1, 20, 4)
            # key_mask = (
            #     torch.sum(var_len_input ** 2, axis=2) < 0.01
            # ).type(torch.FloatTensor)  # 1 means mask and 0 means let through
            # x_self = processed_vec.reshape(-1, processed_vec.shape[1])
            # x_self = self.self_embedding(x_self)  # (b, 1, 64)
            # expanded_x_self = x_self.reshape(-1, 1, 64).repeat(1, 20, 1)
            # obj_emb = self.obs_embeding(var_len_input)
            # objects = torch.cat([expanded_x_self, obj_emb], dim=2)  # (b, 20, 64)
            # obj_and_self = self.self_and_obs_embedding(objects)  # (b, 20, 64)
            # # add the self to the entities
            # # self_and_key_emb = torch.cat(
            # #     [x_self.reshape(-1, 1, 64), obj_and_self], dim=1
            # # )  # (b, 21, 64)
            # # key_mask = torch.cat(
            # #     [torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1
            # # )  # first one is never masked
            # # output, _ = self.attention(
            # #     self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
            # # )  # (b, 21, 64)
            # output, _ = self.attention(
            #     obj_and_self, obj_and_self, obj_and_self, key_mask
            # )  # (b, 21, 64)
            # output = self.dense_after_attention(output) + obj_and_self
            # output = torch.sum(
            #     output * (1 - key_mask).reshape(-1, 20, 1), dim=1
            # ) / (torch.sum(
            #     1 - key_mask, dim=1, keepdim=True
            # ) + 0.001)  # average pooling
|
|
|
            encoding = self.linear_encoder(torch.cat([output, x_self], dim=1))
        else:
            encoding = self.linear_encoder(
                torch.cat([vis_inputs[0].reshape(-1, 80), processed_vec], dim=1)
            )
|
|
|
|
|
|
|
        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
|
|
|