浏览代码

Fixing some issues

/exp-alternate-atten
vincentpierre 4 年前
当前提交
9b3d4ade
共有 2 个文件被更改,包括 28 次插入24 次删除
  1. 6
      Project/Assets/ML-Agents/Examples/Bullet/Scripts/DodgeAgent.cs
  2. 46
      ml-agents/mlagents/trainers/torch/networks.py

6
Project/Assets/ML-Agents/Examples/Bullet/Scripts/DodgeAgent.cs


/// </summary>
public override void OnEpisodeBegin()
{
var rotation = Random.Range(0, 4);
var rotationAngle = rotation * 90f;
area.transform.Rotate(new Vector3(0f, rotationAngle, 0f));
// var rotation = Random.Range(0, 4);
// var rotationAngle = rotation * 90f;
// area.transform.Rotate(new Vector3(0f, rotationAngle, 0f));
transform.position = GetRandomSpawnPos();//
m_AgentRb.velocity = Vector3.zero;

46
ml-agents/mlagents/trainers/torch/networks.py


else 0
)
self.visual_processors, self.vector_processors, self.attention, encoder_input_size = ModelUtils.create_input_processors(
self.visual_processors, self.vector_processors, self.attention, _ = ModelUtils.create_input_processors(
observation_shapes,
self.h_size,
network_settings.vis_encode_type,

total_enc_size = encoder_input_size + encoded_act_size
# total_enc_size = encoder_input_size + encoded_act_size
self.self_embedding = LinearEncoder(6, 1, 64)
self.obs_embeding = LinearEncoder(4 + 64, 1, 64)
self.self_embedding = LinearEncoder(6, 2, 64)
self.obs_embeding = LinearEncoder(4, 2, 64)
total_enc_size, network_settings.num_layers, self.h_size
64 * 2, network_settings.num_layers, self.h_size
)
if self.use_lstm:

# TODO : This is a Hack
var_len_input = vis_inputs[0].reshape(-1, 20, 4)
key_mask = (
key_mask = (
) # 1 means mask and 0 means let though
).type(torch.FloatTensor) # 1 means mask and 0 means let though
objects = torch.cat([expanded_x_self, var_len_input], dim=2) # (b,20,68)
obj_encoding = self.obs_embeding(objects) # (b,20,64)
obj_emb = self.obs_embeding(var_len_input)
objects = torch.cat([expanded_x_self, obj_emb], dim=2) # (b,20,64)
obj_and_self = self.self_and_obs_embedding(objects) # (b,20,64)
self_and_key_emb = torch.cat(
[x_self.reshape(-1, 1, 64), obj_encoding], dim=1
) # (b,21,64)
key_mask = torch.cat(
[torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1
) # first one is never masked
# self_and_key_emb = torch.cat(
# [x_self.reshape(-1, 1, 64), obj_and_self], dim=1
# ) # (b,21,64)
# key_mask = torch.cat(
# [torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1
# ) # first one is never masked
# output, _ = self.attention(
# self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
# ) # (b, 21, 64)
self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
obj_and_self, obj_and_self, obj_and_self, key_mask
output = self.dense_after_attention(output) + self_and_key_emb
output = self.dense_after_attention(output) + obj_and_self
output * (1 - key_mask).reshape(-1, 21, 1), dim=1
) / torch.sum(
output * (1 - key_mask).reshape(-1, 20, 1), dim=1
) / (torch.sum(
) # average pooling
) + 0.001 ) # average pooling
encoding = self.linear_encoder(output + x_self)
encoding = self.linear_encoder(torch.cat([output , x_self], dim=1))
if self.use_lstm:
# Resize to (batch, sequence length, encoding size)

正在加载...
取消
保存