
additional changes

/exp-alternate-atten
vincentpierre, 4 years ago
Commit c264b4da
4 changed files with 17 additions and 17 deletions
  1. Project/Assets/ML-Agents/Examples/Bullet/Prefabs/BulletArea.prefab (4 changes)
  2. Project/Assets/ML-Agents/Examples/Bullet/Scripts/AttentionSensorComponent.cs (5 changes)
  3. Project/Assets/ML-Agents/Examples/Bullet/Scripts/DodgeAgent.cs (4 changes)
  4. ml-agents/mlagents/trainers/torch/networks.py (21 changes)

Project/Assets/ML-Agents/Examples/Bullet/Prefabs/BulletArea.prefab (4 changes)


     - target: {fileID: 5229809530250264056, guid: 7a515d80a9d5b4cfaba55aa066594016,
         type: 3}
       propertyPath: m_IsActive
-      value: 1
+      value: 0
       objectReference: {fileID: 0}
   m_RemovedComponents: []
   m_SourcePrefab: {fileID: 100100000, guid: 7a515d80a9d5b4cfaba55aa066594016, type: 3}

     - target: {fileID: 5229809530250264056, guid: 16266ae2c040142b1a996878d94ec3fb,
         type: 3}
       propertyPath: m_IsActive
-      value: 0
+      value: 1
       objectReference: {fileID: 0}
   m_RemovedComponents:
   - {fileID: 9204493211719777360, guid: 16266ae2c040142b1a996878d94ec3fb, type: 3}

Project/Assets/ML-Agents/Examples/Bullet/Scripts/AttentionSensorComponent.cs (5 changes)


 var bullets = m_AgentTransform.parent.GetComponentsInChildren<Bullet>();
 // Sort by closest :
 Array.Sort(bullets , (a, b) => Vector3.Distance(a.transform.position, m_AgentTransform.position) - Vector3.Distance(b.transform.position, m_AgentTransform.position) > 0 ? 1 : -1);
-m_ObservationBuffer[m_CurrentNumObservables * m_ObservableSize + 0] = b.transform.position.x - m_AgentTransform.position.x;
-m_ObservationBuffer[m_CurrentNumObservables * m_ObservableSize + 1] = b.transform.position.z - m_AgentTransform.position.z;
+m_ObservationBuffer[m_CurrentNumObservables * m_ObservableSize + 0] = (b.transform.position.x - m_AgentTransform.parent.position.x) / 10f;
+m_ObservationBuffer[m_CurrentNumObservables * m_ObservableSize + 1] = (b.transform.position.z - m_AgentTransform.parent.position.z) / 10f;
 m_ObservationBuffer[m_CurrentNumObservables * m_ObservableSize + 2] = b.transform.forward.x;
 m_ObservationBuffer[m_CurrentNumObservables * m_ObservableSize + 3] = b.transform.forward.z;
 m_CurrentNumObservables += 1;
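
Taken together, the sensor now writes four floats per bullet — x/z offset from the parent area's position scaled by 1/10, plus the bullet's forward x/z — with the closest bullets first. That is the layout networks.py later reshapes into (-1, 20, 4) entities. A minimal NumPy sketch of that layout, assuming a 20-slot buffer that stays zero for absent bullets (the function and variable names are illustrative, not from the repo):

import numpy as np

MAX_BULLETS = 20      # matches the (-1, 20, 4) reshape in networks.py
OBSERVABLE_SIZE = 4   # rel_x, rel_z, forward_x, forward_z

def build_bullet_observations(agent_xz, area_xz, bullets):
    # bullets: list of (position_xz, forward_xz) pairs.
    # Unused slots stay zero; networks.py detects those rows
    # (sum of squares < 0.01) and masks them out of the attention.
    buf = np.zeros((MAX_BULLETS, OBSERVABLE_SIZE), dtype=np.float32)
    # Closest bullets first, mirroring the Array.Sort call above.
    bullets = sorted(
        bullets,
        key=lambda b: np.linalg.norm(np.asarray(b[0]) - np.asarray(agent_xz)),
    )
    for i, (pos, forward) in enumerate(bullets[:MAX_BULLETS]):
        buf[i, 0] = (pos[0] - area_xz[0]) / 10.0  # x offset from area, scaled
        buf[i, 1] = (pos[1] - area_xz[1]) / 10.0  # z offset from area, scaled
        buf[i, 2] = forward[0]                    # bullet heading x
        buf[i, 3] = forward[1]                    # bullet heading z
    return buf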

Project/Assets/ML-Agents/Examples/Bullet/Scripts/DodgeAgent.cs (4 changes)


 public override void CollectObservations(VectorSensor sensor)
 {
-    sensor.AddObservation(transform.position.x - area.transform.position.x);
-    sensor.AddObservation(transform.position.z - area.transform.position.z);
+    sensor.AddObservation((transform.position.x - area.transform.position.x) / 10f);
+    sensor.AddObservation((transform.position.z - area.transform.position.z) / 10f);
 }
 /// <summary>

ml-agents/mlagents/trainers/torch/networks.py (21 changes)


     normalize=self.normalize,
 )
 total_enc_size = encoder_input_size + encoded_act_size
 self.self_embedding = LinearEncoder(6, 1, 64)

 # TODO : This is a Hack
 var_len_input = vis_inputs[0].reshape(-1, 20, 4)
-key_mask = 0 * (
+key_mask = (
     torch.sum(var_len_input ** 2, axis=2) < 0.01
 ) # 1 means mask and 0 means let though

-objects = torch.cat([expanded_x_self, var_len_input], dim=2) #(b,20,68)
+objects = torch.cat([expanded_x_self, var_len_input], dim=2) # (b,20,68)
-obj_encoding = self.obs_embeding(objects)#(b,20,64)
+obj_encoding = self.obs_embeding(objects) # (b,20,64)
-self_and_key_emb = torch.cat([x_self.reshape(-1,1,64), obj_encoding], dim=1) #(b,21,64)
+self_and_key_emb = torch.cat(
+    [x_self.reshape(-1, 1, 64), obj_encoding], dim=1
+) # (b,21,64)
 ) # first one is never masked
 output, _ = self.attention(
     self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask

-) / torch.sum(1 - key_mask, dim=1, keepdim=True) # average pooling
+) / torch.sum(
+    1 - key_mask, dim=1, keepdim=True
+) # average pooling
 if self.use_lstm:
     # Resize to (batch, sequence length, encoding size)
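
The functional change in this hunk is dropping the `0 *` factor, so zero-padded entity rows are actually masked instead of the mask being all zeros; the remaining edits are formatting-only line wraps. The overall pattern — zero-padded entities, a key mask derived from them, a self embedding prepended as a never-masked first entity, and a masked average pool over the attention output — can be sketched standalone. The sketch below uses the stock torch.nn.MultiheadAttention and nn.Linear as stand-ins for the trainer's own attention and LinearEncoder modules, so treat it as an illustration of the masking/pooling logic rather than the exact ml-agents code:

import torch
import torch.nn as nn

EMB_SIZE, NUM_ENTITIES, ENTITY_SIZE = 64, 20, 4
obs_embedding = nn.Linear(ENTITY_SIZE + EMB_SIZE, EMB_SIZE)  # stand-in for self.obs_embeding
attention = nn.MultiheadAttention(EMB_SIZE, num_heads=4, batch_first=True)

def attend(x_self, var_len_input):
    # x_self: (b, 64) self embedding; var_len_input: (b, 20, 4), zero rows for absent bullets.
    b = var_len_input.shape[0]
    # True means "mask this key", False means "let it through" (padding rows are ~zero).
    key_mask = var_len_input.pow(2).sum(dim=2) < 0.01                   # (b, 20)
    expanded_x_self = x_self.unsqueeze(1).expand(-1, NUM_ENTITIES, -1)  # (b, 20, 64)
    objects = torch.cat([expanded_x_self, var_len_input], dim=2)        # (b, 20, 68)
    obj_encoding = obs_embedding(objects)                               # (b, 20, 64)
    self_and_key_emb = torch.cat(
        [x_self.reshape(b, 1, EMB_SIZE), obj_encoding], dim=1
    )                                                                   # (b, 21, 64)
    key_mask = torch.cat(
        [torch.zeros(b, 1, dtype=torch.bool), key_mask], dim=1
    )                                                                   # first one is never masked
    output, _ = attention(
        self_and_key_emb, self_and_key_emb, self_and_key_emb,
        key_padding_mask=key_mask,
    )                                                                   # (b, 21, 64)
    keep = (~key_mask).unsqueeze(-1).float()                            # (b, 21, 1)
    return (output * keep).sum(dim=1) / keep.sum(dim=1)                 # average pooling over unmasked slots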
