
Bigger network, needed to solve

/exp-alternate-atten
vincentpierre · 4 years ago
Commit 58f38662
1 changed file with 7 insertions and 5 deletions
ml-agents/mlagents/trainers/torch/networks.py (12 lines changed: +7 −5)

     emb_size = 64
     self.transformer = SimpleTransformer(
-        x_self_size=encoder_input_size,
-        entities_sizes=[4],  # hard coded, 4 obs per entity
+        x_self_size=64,
+        entities_sizes=[64],  # hard coded, 4 obs per entity
-    # self.self_embedding = LinearEncoder(6, 2, 64)
-    # self.obs_embeding = LinearEncoder(4, 2, 64)
+    self.self_embedding = LinearEncoder(6, 2, 64)
+    self.obs_embeding = LinearEncoder(4, 2, 64)
-        emb_size + encoder_input_size, network_settings.num_layers, self.h_size
+        64 + 64, network_settings.num_layers, self.h_size
     )
     if self.use_lstm:

     inputs = torch.cat(encodes, dim=-1)
     x_self = processed_vec.reshape(-1, processed_vec.shape[1])
+    x_self = self.self_embedding(x_self)
+    var_len_input = self.obs_embeding(var_len_input)
     output = self.transformer(x_self, [var_len_input])
     # # TODO : This is a Hack
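
For context: after this change both the flattened self observation (6 values) and each entity observation (4 values) are embedded to 64 units before attention, and the encoder that follows takes 64 + 64 = 128 features (attention output plus self embedding). The sketch below illustrates that data flow with plain torch.nn stand-ins; LinearEncoder and SimpleTransformer are ml-agents classes whose exact signatures and internals in this experimental branch are not reproduced here, so treat the module definitions and the final concatenation as assumptions.

import torch
import torch.nn as nn

EMB_SIZE = 64

def linear_encoder(input_size, num_layers, hidden_size):
    # Stand-in for LinearEncoder(input_size, num_layers, hidden_size): a small MLP.
    layers = [nn.Linear(input_size, hidden_size), nn.ReLU()]
    for _ in range(num_layers - 1):
        layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
    return nn.Sequential(*layers)

self_embedding = linear_encoder(6, 2, EMB_SIZE)   # 6 self obs -> 64
obs_embedding = linear_encoder(4, 2, EMB_SIZE)    # 4 obs per entity -> 64
# Stand-in for SimpleTransformer: the self embedding attends over entity embeddings.
attention = nn.MultiheadAttention(embed_dim=EMB_SIZE, num_heads=4, batch_first=True)
trunk = linear_encoder(EMB_SIZE + EMB_SIZE, 2, 128)  # the "64 + 64" input in the diff

batch, n_entities = 8, 5
x_self = torch.rand(batch, 6)                     # flattened self observation
var_len_input = torch.rand(batch, n_entities, 4)  # variable-length entity observations

x_self_emb = self_embedding(x_self)               # (batch, 64)
entity_emb = obs_embedding(var_len_input)         # (batch, n_entities, 64)
attended, _ = attention(x_self_emb.unsqueeze(1), entity_emb, entity_emb)
attended = attended.squeeze(1)                    # (batch, 64)

# Concatenate attention output with the self embedding, as the 64 + 64 sizing suggests.
encoding = trunk(torch.cat([attended, x_self_emb], dim=-1))  # (batch, 128) in and out here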
