vincentpierre
4 年前
当前提交
8cb050ef
共有 38 个文件被更改,包括 777 次插入 和 119 次删除
-
6com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
-
33com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
-
6gym-unity/gym_unity/envs/__init__.py
-
4gym-unity/gym_unity/tests/test_gym.py
-
57ml-agents-envs/mlagents_envs/base_env.py
-
19ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
-
6ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
-
11ml-agents-envs/mlagents_envs/rpc_utils.py
-
14ml-agents-envs/mlagents_envs/tests/test_envs.py
-
2ml-agents-envs/mlagents_envs/tests/test_registry.py
-
31ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
-
7ml-agents-envs/mlagents_envs/tests/test_steps.py
-
8ml-agents/mlagents/trainers/demo_loader.py
-
6ml-agents/mlagents/trainers/policy/policy.py
-
2ml-agents/mlagents/trainers/policy/torch_policy.py
-
4ml-agents/mlagents/trainers/sac/optimizer_torch.py
-
14ml-agents/mlagents/trainers/tests/mock_brain.py
-
10ml-agents/mlagents/trainers/tests/simple_test_envs.py
-
4ml-agents/mlagents/trainers/tests/test_demo_loader.py
-
6ml-agents/mlagents/trainers/tests/torch/test_hybrid.py
-
2ml-agents/mlagents/trainers/tests/torch/test_ppo.py
-
40ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
-
14ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py
-
28ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
-
28ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py
-
4ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py
-
2ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
-
2ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
-
2ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
-
4ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py
-
4ml-agents/mlagents/trainers/torch/model_serialization.py
-
4ml-agents/tests/yamato/scripts/run_llapi.py
-
1protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
-
84com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
-
27com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
-
47com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs
-
162ml-agents/mlagents/trainers/tests/torch/test_attention.py
-
191ml-agents/mlagents/trainers/torch/attention.py
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
public class BufferSensor : ISensor |
|||
{ |
|||
private int m_MaxNumObs; |
|||
private int m_ObsSize; |
|||
float[] m_ObservationBuffer; |
|||
int m_CurrentNumObservables; |
|||
public BufferSensor(int maxNumberObs, int obsSize) |
|||
{ |
|||
m_MaxNumObs = maxNumberObs; |
|||
m_ObsSize = obsSize; |
|||
m_ObservationBuffer = new float[m_ObservableSize * m_MaxNumObservables]; |
|||
m_CurrentNumObservables = 0; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public int[] GetObservationShape() |
|||
{ |
|||
return new int[] { m_MaxNumObs, m_ObsSize }; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Appends an observation to the buffer. If the buffer is full (maximum number
|
|||
/// of observation is reached) the observation will be ignored. the length of
|
|||
/// the provided observation array must be equal to the observation size of
|
|||
/// the buffer sensor.
|
|||
/// </summary>
|
|||
/// <param name="obs"> The float array observation</param>
|
|||
public void AppendObservation(float[] obs) |
|||
{ |
|||
if (m_CurrentNumObservables >= m_MaxNumObs) |
|||
{ |
|||
return; |
|||
} |
|||
for (int i = 0; i < obs.Length; i++) |
|||
{ |
|||
m_ObservationBuffer[m_CurrentNumObservables * m_MaxNumObs + i] = obs[i]; |
|||
} |
|||
m_CurrentNumObservables++; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public int Write(ObservationWriter writer) |
|||
{ |
|||
for (int i = 0; i < m_ObservableSize * m_MaxNumObservables; i++) |
|||
{ |
|||
writer[i] = m_ObservationBuffer[i]; |
|||
} |
|||
return m_ObservableSize * m_MaxNumObservables; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public virtual byte[] GetCompressedObservation() |
|||
{ |
|||
return null; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Update() |
|||
{ |
|||
Reset(); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Reset() |
|||
{ |
|||
m_CurrentNumObservables = 0; |
|||
Array.Clear(m_ObservationBuffer, 0, m_ObservationBuffer.Length); |
|||
} |
|||
|
|||
public SensorCompressionType GetCompressionType() |
|||
{ |
|||
return SensorCompressionType.None; |
|||
} |
|||
|
|||
public string GetName() |
|||
{ |
|||
return "BufferSensor"; |
|||
} |
|||
|
|||
} |
|||
|
|||
} |
|
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// A component for BufferSensor.
|
|||
/// </summary>
|
|||
[AddComponentMenu("ML Agents/Buffer Sensor", (int)MenuGroup.Sensors)] |
|||
public class BufferSensorComponent : SensorComponent |
|||
{ |
|||
public int ObservableSize; |
|||
public int MaxNumObservables; |
|||
|
|||
/// <inheritdoc/>
|
|||
public override ISensor CreateSensor() |
|||
{ |
|||
return new BufferSensor(ObservableSize, MaxNumObservables); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public override int[] GetObservationShape() |
|||
{ |
|||
return new[] { MaxNumObservables, ObservableSize }; |
|||
} |
|||
} |
|||
} |
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// The Dimension property flags of the observations
|
|||
/// </summary>
|
|||
[System.Flags] |
|||
internal enum DimensionProperty |
|||
{ |
|||
/// <summary>
|
|||
/// No properties specified.
|
|||
/// </summary>
|
|||
Unspecified = 0, |
|||
|
|||
/// <summary>
|
|||
/// No Property of the observation in that dimension. Observation can be processed with
|
|||
/// Fully connected networks.
|
|||
/// </summary>
|
|||
None = 1, |
|||
|
|||
/// <summary>
|
|||
/// Means it is possible to do a convolution in this dimension.
|
|||
/// </summary>
|
|||
TranslationalEquivariance = 2, |
|||
|
|||
/// <summary>
|
|||
/// Means that there can be a variable number of observations in this dimension.
|
|||
/// The observations are unordered.
|
|||
/// </summary>
|
|||
VariableSize = 3, |
|||
} |
|||
|
|||
|
|||
/// <summary>
|
|||
/// Sensor interface for sensors with special dimension properties.
|
|||
/// </summary>
|
|||
internal interface IDimensionPropertiesSensor : ISensor |
|||
{ |
|||
/// <summary>
|
|||
/// Returns the array containing the properties of each dimensions of the
|
|||
/// observation. The length of the array must be equal to the rank of the
|
|||
/// observation tensor.
|
|||
/// </summary>
|
|||
/// <returns>The array of DimensionProperty</returns>
|
|||
DimensionProperty[] GetDimensionProperties(); |
|||
} |
|||
} |
|
|||
from mlagents.torch_utils import torch |
|||
import numpy as np |
|||
|
|||
from mlagents.trainers.torch.layers import linear_layer |
|||
from mlagents.trainers.torch.attention import MultiHeadAttention, SimpleTransformer |
|||
|
|||
|
|||
def test_multi_head_attention_initialization(): |
|||
q_size, k_size, v_size, o_size, n_h, emb_size = 7, 8, 9, 10, 11, 12 |
|||
n_k, n_q, b = 13, 14, 15 |
|||
mha = MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size) |
|||
|
|||
query = torch.ones((b, n_q, q_size)) |
|||
key = torch.ones((b, n_k, k_size)) |
|||
value = torch.ones((b, n_k, v_size)) |
|||
|
|||
output, attention = mha.forward(query, key, value) |
|||
|
|||
assert output.shape == (b, n_q, o_size) |
|||