|
|
|
|
|
|
self.use_fc = False |
|
|
|
|
|
|
|
if not self.use_fc: |
|
|
|
emb_size = 16 |
|
|
|
emb_size = 64 |
|
|
|
x_self_size=16, |
|
|
|
entities_sizes=[16], # hard coded, 4 obs per entity |
|
|
|
x_self_size=6, |
|
|
|
entities_sizes=[4], # hard coded, 4 obs per entity |
|
|
|
embedding_size=emb_size, |
|
|
|
output_size = self.h_size |
|
|
|
) |
|
|
|
|
|
|
self.self_embedding = LinearEncoder(6, 2, 16) |
|
|
|
self.obs_embeding = LinearEncoder(4, 2, 16) |
|
|
|
# self.self_embedding = LinearEncoder(6, 2, 16) |
|
|
|
# self.obs_embeding = LinearEncoder(4, 2, 16) |
|
|
|
# self.self_and_obs_embedding = LinearEncoder(64 + 64, 1, 64) |
|
|
|
# self.dense_after_attention = LinearEncoder(64, 1, 64) |
|
|
|
|
|
|
|
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.vector_processors): |
|
|
|
vec_input = vec_inputs[idx] |
|
|
|
processed_vec = processor(vec_input) |
|
|
|
encodes.append(processed_vec) |
|
|
|
# encodes = [] |
|
|
|
# for idx, processor in enumerate(self.vector_processors): |
|
|
|
# vec_input = vec_inputs[idx] |
|
|
|
# processed_vec = processor(vec_input) |
|
|
|
# encodes.append(processed_vec) |
|
|
|
for idx, processor in enumerate(self.visual_processors): |
|
|
|
vis_input = vis_inputs[idx] |
|
|
|
if not exporting_to_onnx.is_exporting(): |
|
|
|
vis_input = vis_input.permute([0, 3, 1, 2]) |
|
|
|
processed_vis = processor(vis_input) |
|
|
|
encodes.append(processed_vis) |
|
|
|
# for idx, processor in enumerate(self.visual_processors): |
|
|
|
# vis_input = vis_inputs[idx] |
|
|
|
# if not exporting_to_onnx.is_exporting(): |
|
|
|
# vis_input = vis_input.permute([0, 3, 1, 2]) |
|
|
|
# processed_vis = processor(vis_input) |
|
|
|
# encodes.append(processed_vis) |
|
|
|
if len(encodes) == 0: |
|
|
|
raise Exception("No valid inputs to network.") |
|
|
|
# if len(encodes) == 0: |
|
|
|
# raise Exception("No valid inputs to network.") |
|
|
|
if actions is not None: |
|
|
|
inputs = torch.cat(encodes + [actions], dim=-1) |
|
|
|
else: |
|
|
|
inputs = torch.cat(encodes, dim=-1) |
|
|
|
# if actions is not None: |
|
|
|
# inputs = torch.cat(encodes + [actions], dim=-1) |
|
|
|
# else: |
|
|
|
# inputs = torch.cat(encodes, dim=-1) |
|
|
|
x_self = self.self_embedding(processed_vec) |
|
|
|
# x_self = self.self_embedding(vec_inputs[0]) |
|
|
|
# print(vis_inputs[0].shape) |
|
|
|
processed_var_len_input = self.obs_embeding(var_len_input) |
|
|
|
# processed_var_len_input = self.obs_embeding(var_len_input) |
|
|
|
output = self.transformer(x_self, [processed_var_len_input], masks) |
|
|
|
# if exporting_to_onnx.is_exporting(): |
|
|
|
# tmp = var_len_input.reshape(-1, 1, 20, 4) |
|
|
|
# # permute with `permute([0,2,3,1])` |
|
|
|
# tmp = tmp.permute([0, 2, 1, 3]) # (b, h, emb, n_k) |
|
|
|
# tmp -= 1 |
|
|
|
# tmp += 1 |
|
|
|
# tmp = tmp.permute([0, 1, 3, 2]) # (b, h, emb, n_k) |
|
|
|
# tmp = tmp.reshape(-1, 20, 4) |
|
|
|
# masks = SimpleTransformer.get_masks([tmp]) |
|
|
|
|
|
|
|
output = self.transformer(vec_inputs[0], [var_len_input], masks) |
|
|
|
|
|
|
|
# # TODO : This is a Hack |
|
|
|
# var_len_input = vis_inputs[0].reshape(-1, 20, 4) |
|
|
|
|
|
|
|
|
|
|
encoding = output |
|
|
|
else: |
|
|
|
encoding = self.linear_encoder(torch.cat([vis_inputs[0].reshape(-1, 80), processed_vec], dim=1)) |
|
|
|
encoding = self.linear_encoder(torch.cat([vis_inputs[0].reshape(-1, 80), vec_inputs[0]], dim=1)) |
|
|
|
# encoding = self.linear_encoder(torch.cat([vis_inputs[0].reshape(-1, 80), processed_vec], dim=1)) |
|
|
|
|
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
# Resize to (batch, sequence length, encoding size) |
|
|
|