浏览代码

Trainer with attention

/exp-alternate-atten
vincentpierre 4 年前
当前提交
7ef3c9a1
共有 4 个文件被更改,包括 22 次插入12 次删除
  1. 2
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  2. 6
      ml-agents/mlagents/trainers/torch/layers.py
  3. 18
      ml-agents/mlagents/trainers/torch/networks.py
  4. 8
      ml-agents/mlagents/trainers/torch/utils.py

2
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


# vis_obs.append(vis_ob)
# else:
# vis_obs = []
vis_obs = [ ModelUtils.list_to_tensor(batch["visual_obs%d" % 0])]
vis_obs = [ModelUtils.list_to_tensor(batch["visual_obs%d" % 0])]
log_probs, entropy, values = self.policy.evaluate_actions(
vec_obs,
vis_obs,

6
ml-agents/mlagents/trainers/torch/layers.py


)
def forward(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_mask:torch.Tensor
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
b, n_q, n_k = query.size(0), query.size(1), key.size(1)

18
ml-agents/mlagents/trainers/torch/networks.py


self.obs_embeding = LinearEncoder(4, 1, 64)
self.self_and_obs_embedding = LinearEncoder(64 + 64, 1, 64)
if self.use_lstm:
self.lstm = LSTM(self.h_size, self.m_size)
else:

# TODO : This is a Hack
var_len_input = vis_inputs[0].reshape(-1, 20, 4)
key_mask = torch.sum(var_len_input ** 2, axis=2) < 0.01 # 1 means mask and 0 means let though
key_mask = (
torch.sum(var_len_input ** 2, axis=2) < 0.01
) # 1 means mask and 0 means let though
self_encoding = processed_vec.reshape(-1, 1, processed_vec.shape[1])
self_encoding = self.self_embedding(self_encoding) # (b, 64)

# add the self to the entities
self_and_key_emb = torch.cat([self_encoding, self_and_key_emb], dim=1)
key_mask = torch.cat([torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1)
key_mask = torch.cat(
[torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1
)
output, _ = self.attention(self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask)
output = torch.sum(output * (1 - key_mask).reshape(-1,21,1), dim=1) / torch.sum(1-key_mask, dim=1, keepdim=True)
output, _ = self.attention(
self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
)
output = torch.sum(
output * (1 - key_mask).reshape(-1, 21, 1), dim=1
) / torch.sum(1 - key_mask, dim=1, keepdim=True)
# output = torch.cat([inputs, output], dim=1)

8
ml-agents/mlagents/trainers/torch/utils.py


elif len(dimension) == 1:
vector_size += dimension[0]
else:
print(#raise UnityTrainerException(
print( # raise UnityTrainerException(
f"Unsupported shape of {dimension} for observation {i}"
)
if vector_size > 0:

max_observables, observable_size, output_size = (20, 4, 64)
attention = MultiHeadAttention(
query_size=output_size,
key_size= output_size,
key_size=output_size,
embedding_size=64
embedding_size=64,
)
return (

output_size#total_processed_size + output_size,
output_size, # total_processed_size + output_size,
)
@staticmethod

正在加载...
取消
保存