浏览代码

adding encoding to self and var len input

/develop/singular-embeddings
vincentpierre 4 年前
当前提交
af58753b
共有 1 个文件被更改,包括 17 次插入3 次删除
  1. 20
      ml-agents/mlagents/trainers/torch/networks.py

20
ml-agents/mlagents/trainers/torch/networks.py


entities_sizes = [sensor_specs[idx].shape[1] for idx in var_len_indices]
entities_max_len = [sensor_specs[idx].shape[0] for idx in var_len_indices]
self.x_self_encoder = LinearEncoder(x_self_len, 2, self.h_size // 2)
self.var_len_encoders = torch.nn.ModuleList(
[
LinearEncoder(ent_size, 2, self.h_size // 2)
for ent_size in entities_sizes
]
)
x_self_len, entities_sizes, entities_max_len, self.h_size
self.h_size // 2,
[self.h_size // 2] * len(var_len_indices),
entities_max_len,
self.h_size,
total_enc_size = x_self_len + self.h_size
total_enc_size = self.h_size // 2 + self.h_size
n_layers = max(1, network_settings.num_layers - 2)
else:

if len(var_len_inputs) > 0:
# Some inputs need to be processed with a variable length encoder
masks = EntityEmbeddings.get_masks(var_len_inputs)
qkv = self.entities_embeddings(encoded_self, var_len_inputs)
encoded_self = self.x_self_encoder(encoded_self)
encoded_var_len = [
encoder(x) for encoder, x in zip(self.var_len_encoders, var_len_inputs)
]
qkv = self.entities_embeddings(encoded_self, encoded_var_len)
mu_qkv = torch.mean(qkv, dim=2, keepdim=True)
qkv = (qkv - mu_qkv) / (
torch.sqrt(torch.mean((qkv - mu_qkv) ** 2, dim=2, keepdim=True))

正在加载...
取消
保存