
concat x_self before attention

/exp-alternate-atten
Andrew Cohen, 4 years ago
Commit 84cc2b84
2 files changed, 34 insertions(+), 23 deletions(-)
1. ml-agents/mlagents/trainers/torch/layers.py (29 changes)
2. ml-agents/mlagents/trainers/torch/networks.py (28 changes)

ml-agents/mlagents/trainers/torch/layers.py (29 changes)


      # This is to avoid using .size() when possible as Barracuda does not support it
-     n_q = number_of_queries if number_of_queries != -1 else query.size(1)
-     n_k = number_of_keys if number_of_keys != -1 else key.size(1)
+     n_q = 20
+     n_k = 20
      # Create a key mask : Only 1 if all values are 0 # shape = (b, n_k)
      # key_mask = torch.sum(key ** 2, axis=2) < 0.01
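For context: the comment above explains the hardcoding, since Barracuda (Unity's inference engine) does not support the dynamic shape lookups that .size() produces in an exported graph, so this experiment pins the entity count to 20. A minimal sketch of the fallback pattern (the function name and the fixed count of 20 are assumptions of this experiment, not library API):

import torch

def sequence_lengths(query: torch.Tensor, key: torch.Tensor,
                     number_of_queries: int = -1,
                     number_of_keys: int = -1):
    # Prefer statically known lengths so the exported graph carries no
    # dynamic-shape ops; fall back to .size() only when export is not a
    # concern (e.g. during training).
    n_q = number_of_queries if number_of_queries != -1 else query.size(1)
    n_k = number_of_keys if number_of_keys != -1 else key.size(1)
    return n_q, n_k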

      )
      self.residual_layer = LinearEncoder(embedding_size, 1, embedding_size)

-     def forward(self, x_self: torch.Tensor, entities: List[torch.Tensor], key_masks: List[torch.Tensor]):
+     def forward(self, entities: torch.Tensor, key_masks: List[torch.Tensor]):
          # Gather the maximum number of entities information
          if self.entities_num_max_elements is None:
              self.entities_num_max_elements = []

-         self_and_ent: List[torch.Tensor] = []
-         for num_entities, ent in zip(self.entities_num_max_elements, entities):
-             expanded_self = x_self.reshape(-1, 1, self.self_size)
-             # .repeat(
-             #     1, num_entities, 1
-             # )
-             expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
-             self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
+         # for num_entities, ent in zip(self.entities_num_max_elements, entities):
+         #     expanded_self = x_self.reshape(-1, 1, self.self_size)
+         #     # .repeat(
+         #     #     1, num_entities, 1
+         #     # )
+         #     expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+         #     self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
-         qkv = torch.cat(
-             [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
-             dim=1,
-         )
+         # qkv = torch.cat(
+         #     [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
+         #     dim=1,
+         # )
+         qkv = entities
-         max_num_ent = sum(self.entities_num_max_elements)
+         max_num_ent = 20  # sum(self.entities_num_max_elements)
          output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
          # Residual
          output = self.residual_layer(output) + qkv
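Net effect of this hunk: the attention block no longer builds (self, entity) pairs internally; it self-attends over tokens that already contain x_self. A minimal sketch of the resulting forward pass, assuming 20 entity tokens per observation; nn.MultiheadAttention and nn.Linear stand in for the repo's own MultiHeadAttention and LinearEncoder:

import torch
import torch.nn as nn

class ResidualSelfAttentionSketch(nn.Module):
    # Self-attention over pre-concatenated (self + entity) tokens.
    def __init__(self, embedding_size: int, num_heads: int = 4):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embedding_size, num_heads, batch_first=True
        )
        self.residual_layer = nn.Linear(embedding_size, embedding_size)

    def forward(self, entities: torch.Tensor, key_mask: torch.Tensor) -> torch.Tensor:
        # entities: (batch, 20, embedding_size); x_self is already mixed into
        # every token, so no per-entity concatenation happens here anymore.
        qkv = entities
        output, _ = self.attention(qkv, qkv, qkv, key_padding_mask=key_mask)
        # Residual connection, as in the hunk above.
        return self.residual_layer(output) + qkv

tokens = torch.randn(8, 20, 64)
mask = torch.zeros(8, 20, dtype=torch.bool)  # True would mask an entity out
out = ResidualSelfAttentionSketch(64)(tokens, mask)  # (8, 20, 64)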

ml-agents/mlagents/trainers/torch/networks.py (28 changes)


      self.use_fc = False
      if not self.use_fc:
-         emb_size = 32
+         emb_size = 64
-             x_self_size=32,
-             entities_sizes=[32],  # hard coded, 4 obs per entity
+             x_self_size=64,
+             entities_sizes=[64],  # hard coded, 4 obs per entity
-         self.self_embedding = LinearEncoder(6, 2, 32)
-         self.obs_embeding = LinearEncoder(4, 2, 32)
+         # self.self_embedding = LinearEncoder(6, 2, 64)
+         # self.obs_embeding = LinearEncoder(4, 2, 64)
+         # 6 self (2 coord x 3 stacks) and 4 observable = 10
+         self.embedding_encoder = LinearEncoder(10, 2, 64)
-             emb_size + 32, network_settings.num_layers - 1, self.h_size
+             emb_size, network_settings.num_layers - 1, self.h_size
          )
      else:
          self.linear_encoder = LinearEncoder(
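The size changes line up with the new input layout: 6 self dims (2 coordinates x 3 stacked frames) plus 4 observable dims give 10-dim entity tokens, embedded at width 64, and the final encoder no longer needs the extra +32 for a separately appended x_self. A sketch of the bookkeeping, with linear_encoder_sketch as an assumed stand-in for the repo's LinearEncoder:

import torch
import torch.nn as nn

def linear_encoder_sketch(input_size: int, num_layers: int, hidden_size: int) -> nn.Module:
    # Stand-in for LinearEncoder: an MLP with the same
    # (input_size, num_layers, hidden_size) signature.
    layers = [nn.Linear(input_size, hidden_size), nn.ReLU()]
    for _ in range(num_layers - 1):
        layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
    return nn.Sequential(*layers)

# 6 self dims + 4 observable dims = 10 per entity token, embedded to 64.
embedding_encoder = linear_encoder_sketch(10, 2, 64)
tokens = torch.randn(8, 20, 10)       # (batch, entities, self + obs)
embedded = embedding_encoder(tokens)  # (batch, 20, 64)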

      inputs = torch.cat(encodes, dim=-1)
      if not self.use_fc:
-         x_self = self.self_embedding(processed_vec)
          # x_self = self.embedding(processed_vec)
-         processed_var_len_input = self.obs_embeding(var_len_input)
-         output = self.transformer(x_self, [processed_var_len_input], self.masking_module([var_len_input]))
+         expanded_self = processed_vec.reshape(-1, 1, 6)
+         expanded_self = torch.cat([expanded_self] * 20, dim=1)
+         concat_ent = torch.cat([expanded_self, var_len_input], dim=-1)
+         processed_var_len_input = self.embedding_encoder(concat_ent)
+         output = self.transformer(processed_var_len_input, self.masking_module([var_len_input]))
+         # output = self.transformer(x_self, [processed_var_len_input], self.masking_module([var_len_input]))
          # # TODO : This is a Hack
          # var_len_input = vis_inputs[0].reshape(-1, 20, 4)

          #     1 - key_mask, dim=1, keepdim=True
          # ) + 0.001 )  # average pooling
-         encoding = self.linear_encoder(torch.cat([output, x_self], dim=1))
+         # encoding = self.linear_encoder(torch.cat([output, x_self], dim=1))
+         encoding = self.linear_encoder(output)
      else:
          encoding = self.linear_encoder(torch.cat([vis_inputs[0].reshape(-1, 80), processed_vec], dim=1))
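Putting the forward path together: the self vector is now tiled once per entity and concatenated before any embedding, so the attention output alone feeds the final encoder. A minimal sketch under the diff's assumptions (batch of 6-dim self vectors, 20 entities of 4 observables each; concat_self_before_attention is a hypothetical helper name, not part of the repo):

import torch

def concat_self_before_attention(processed_vec: torch.Tensor,
                                 var_len_input: torch.Tensor) -> torch.Tensor:
    # processed_vec: (batch, 6)      self observation (2 coords x 3 stacks)
    # var_len_input: (batch, 20, 4)  one 4-dim observation per entity
    expanded_self = processed_vec.reshape(-1, 1, 6)
    # Tile the self vector across all 20 entity slots, then append it to
    # each entity's observation, mirroring the new forward pass above.
    expanded_self = torch.cat([expanded_self] * 20, dim=1)    # (batch, 20, 6)
    return torch.cat([expanded_self, var_len_input], dim=-1)  # (batch, 20, 10)

tokens = concat_self_before_attention(torch.randn(8, 6), torch.randn(8, 20, 4))
assert tokens.shape == (8, 20, 10)

Because every token already carries x_self, the final linear_encoder consumes the attention output alone, which is why torch.cat([output, x_self], dim=1) is replaced with plain output.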
