
different architecture

/exp-bullet-hell-trainer
vincentpierre, 4 years ago
Current commit: f283cb60
3 files changed, 41 insertions and 27 deletions
1. ml-agents/mlagents/trainers/torch/layers.py (2 changed lines)
2. ml-agents/mlagents/trainers/torch/networks.py (65 changed lines)
3. ml-agents/mlagents/trainers/torch/utils.py (1 changed line)

ml-agents/mlagents/trainers/torch/layers.py (2 changed lines)


self.entities_num_max_elements: Optional[List[int]] = None
self.ent_encoders = torch.nn.ModuleList(
    [
        LinearEncoder(self.self_size + ent_size, 1, embedding_size)
        LinearEncoder(self.self_size + ent_size, 2, embedding_size)
        for ent_size in self.entities_sizes
    ]
)
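The only change in layers.py is the depth argument passed to each per-entity encoder (1 is replaced by 2). For orientation, the following is a minimal sketch of a LinearEncoder(input_size, num_layers, hidden_size) of the kind this call assumes: a stack of num_layers fully connected layers with Swish activations. It is an illustrative reconstruction, not the actual implementation in layers.py, which may use its own initialization helpers.

import torch

class LinearEncoderSketch(torch.nn.Module):
    # Sketch: num_layers blocks of Linear + Swish (SiLU), the first mapping
    # input_size -> hidden_size, the rest hidden_size -> hidden_size.
    def __init__(self, input_size: int, num_layers: int, hidden_size: int):
        super().__init__()
        layers = [torch.nn.Linear(input_size, hidden_size), torch.nn.SiLU()]
        for _ in range(num_layers - 1):
            layers.append(torch.nn.Linear(hidden_size, hidden_size))
            layers.append(torch.nn.SiLU())
        self.seq_layers = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq_layers(x)

Read this way, the commit deepens each entity encoder from one Linear + Swish block to two while keeping its output width at embedding_size.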

ml-agents/mlagents/trainers/torch/networks.py (65 changed lines)


self.use_fc = False
if not self.use_fc:
emb_size = 16
emb_size = 64
x_self_size=16,
entities_sizes=[16], # hard coded, 4 obs per entity
x_self_size=6,
entities_sizes=[4], # hard coded, 4 obs per entity
embedding_size=emb_size,
output_size = self.h_size
)

self.self_embedding = LinearEncoder(6, 2, 16)
self.obs_embeding = LinearEncoder(4, 2, 16)
# self.self_embedding = LinearEncoder(6, 2, 16)
# self.obs_embeding = LinearEncoder(4, 2, 16)
# self.self_and_obs_embedding = LinearEncoder(64 + 64, 1, 64)
# self.dense_after_attention = LinearEncoder(64, 1, 64)
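The constructor changes above replace the hard-coded 16-dimensional inputs with this environment's raw observations (a 6-dim self vector and 4-dim entity observations) and widen the embedding from 16 to 64. The module receiving these arguments is not shown in the hunk; the sketch below is a hypothetical stand-in for the SimpleTransformer-style block they suggest: encode each entity together with x_self, attend over the entities with a padding mask, pool, and project to output_size. The class name, head count and mean pooling here are assumptions, not the actual layers.py implementation.

import torch
from typing import List

class SimpleTransformerSketch(torch.nn.Module):
    # Hypothetical stand-in for the entity-attention block configured above.
    def __init__(self, x_self_size, entities_sizes, embedding_size, output_size):
        super().__init__()
        # One encoder per entity type, consuming the concatenation [x_self, entity]
        # (mirrors the ent_encoders built in layers.py above).
        self.ent_encoders = torch.nn.ModuleList(
            [
                torch.nn.Linear(x_self_size + ent_size, embedding_size)
                for ent_size in entities_sizes
            ]
        )
        self.attention = torch.nn.MultiheadAttention(
            embed_dim=embedding_size, num_heads=4, batch_first=True
        )
        self.out_proj = torch.nn.Linear(embedding_size, output_size)

    def forward(
        self,
        x_self: torch.Tensor,            # (batch, x_self_size)
        entities: List[torch.Tensor],    # each (batch, n_entities, ent_size)
        key_masks: List[torch.Tensor],   # each (batch, n_entities), 1.0 = padded slot
    ) -> torch.Tensor:
        encoded = []
        for encoder, ent in zip(self.ent_encoders, entities):
            expanded_self = x_self.unsqueeze(1).expand(-1, ent.shape[1], -1)
            encoded.append(encoder(torch.cat([expanded_self, ent], dim=-1)))
        qkv = torch.cat(encoded, dim=1)
        mask = torch.cat(key_masks, dim=1).bool()
        attended, _ = self.attention(qkv, qkv, qkv, key_padding_mask=mask)
        return self.out_proj(attended.mean(dim=1))

With the new values, each 4-dim entity observation is concatenated with the 6-dim self vector and embedded into 64 dimensions before attention, and the pooled result is projected to self.h_size.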

memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
encodes = []
for idx, processor in enumerate(self.vector_processors):
vec_input = vec_inputs[idx]
processed_vec = processor(vec_input)
encodes.append(processed_vec)
# encodes = []
# for idx, processor in enumerate(self.vector_processors):
# vec_input = vec_inputs[idx]
# processed_vec = processor(vec_input)
# encodes.append(processed_vec)
for idx, processor in enumerate(self.visual_processors):
vis_input = vis_inputs[idx]
if not exporting_to_onnx.is_exporting():
vis_input = vis_input.permute([0, 3, 1, 2])
processed_vis = processor(vis_input)
encodes.append(processed_vis)
# for idx, processor in enumerate(self.visual_processors):
# vis_input = vis_inputs[idx]
# if not exporting_to_onnx.is_exporting():
# vis_input = vis_input.permute([0, 3, 1, 2])
# processed_vis = processor(vis_input)
# encodes.append(processed_vis)
if len(encodes) == 0:
raise Exception("No valid inputs to network.")
# if len(encodes) == 0:
# raise Exception("No valid inputs to network.")
if actions is not None:
inputs = torch.cat(encodes + [actions], dim=-1)
else:
inputs = torch.cat(encodes, dim=-1)
# if actions is not None:
# inputs = torch.cat(encodes + [actions], dim=-1)
# else:
# inputs = torch.cat(encodes, dim=-1)
x_self = self.self_embedding(processed_vec)
# x_self = self.self_embedding(vec_inputs[0])
# print(vis_inputs[0].shape)
processed_var_len_input = self.obs_embeding(var_len_input)
# processed_var_len_input = self.obs_embeding(var_len_input)
output = self.transformer(x_self, [processed_var_len_input], masks)
# if exporting_to_onnx.is_exporting():
# tmp = var_len_input.reshape(-1, 1, 20, 4)
# # permute with `permute([0,2,3,1])`
# tmp = tmp.permute([0, 2, 1, 3]) # (b, h, emb, n_k)
# tmp -= 1
# tmp += 1
# tmp = tmp.permute([0, 1, 3, 2]) # (b, h, emb, n_k)
# tmp = tmp.reshape(-1, 20, 4)
# masks = SimpleTransformer.get_masks([tmp])
output = self.transformer(vec_inputs[0], [var_len_input], masks)
# # TODO : This is a Hack
# var_len_input = vis_inputs[0].reshape(-1, 20, 4)

encoding = output
else:
encoding = self.linear_encoder(torch.cat([vis_inputs[0].reshape(-1, 80), processed_vec], dim=1))
encoding = self.linear_encoder(torch.cat([vis_inputs[0].reshape(-1, 80), vec_inputs[0]], dim=1))
# encoding = self.linear_encoder(torch.cat([vis_inputs[0].reshape(-1, 80), processed_vec], dim=1))
if self.use_lstm:
# Resize to (batch, sequence length, encoding size)
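The transformer calls above also take a masks argument, and the commented-out ONNX workaround builds it with SimpleTransformer.get_masks after reshaping the visual input into (batch, 20, 4) entity rows. A common way to derive such masks is to treat all-zero entity rows as padding; the helper below sketches that idea under that assumption and is not necessarily the exact implementation in layers.py.

import torch
from typing import List

def get_masks_sketch(observations: List[torch.Tensor]) -> List[torch.Tensor]:
    # For each (batch, n_entities, ent_size) tensor, flag entities whose
    # observation vector is numerically all zeros as padding (mask value 1.0).
    with torch.no_grad():
        return [(torch.sum(obs ** 2, dim=2) < 0.01).float() for obs in observations]

For example, masks = get_masks_sketch([vis_inputs[0].reshape(-1, 20, 4)]) would mark the unused entity slots before the transformer call.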

ml-agents/mlagents/trainers/torch/utils.py (1 changed line)


# Total output size for all inputs + CNNs
total_processed_size = vector_size + visual_output_size
print(observation_shapes)
return (
    nn.ModuleList(visual_encoders),
    nn.ModuleList(vector_encoders),
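The single change in utils.py is a temporary print(observation_shapes) debug line, exposing the raw shape list this function turns into visual and vector encoders before returning the two ModuleLists. As a purely hypothetical illustration of the routing that list drives (this helper is not part of utils.py): 3-dimensional (H, W, C) shapes would go to the visual/CNN encoders and 1-dimensional shapes to the vector encoders, which is consistent with the networks.py changes treating vec_inputs[0] as the 6-dim self observation and vis_inputs[0] as packed entity data.

from typing import List, Tuple

def split_observation_shapes(
    observation_shapes: List[Tuple[int, ...]]
) -> Tuple[List[Tuple[int, ...]], List[Tuple[int, ...]]]:
    # Hypothetical helper: route image-like (H, W, C) shapes to visual encoders
    # and flat shapes to vector encoders.
    visual_shapes = [shape for shape in observation_shapes if len(shape) == 3]
    vector_shapes = [shape for shape in observation_shapes if len(shape) == 1]
    return visual_shapes, vector_shapes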
