浏览代码

Support multi-dimensional and compressed observations stacking (#4476)

Added stacking to multi-dimensional and compressed observations and added compressed channel mapping in communicator to support decompression.

Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>
Co-authored-by: Chris Elion <chris.elion@unity3d.com>
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
c3d2b902
共有 28 个文件被更改,包括 832 次插入97 次删除
  1. 12
      com.unity.ml-agents/CHANGELOG.md
  2. 1
      com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
  3. 1
      com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs
  4. 6
      com.unity.ml-agents/Runtime/Academy.cs
  5. 62
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  6. 4
      com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
  7. 40
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
  8. 34
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
  9. 54
      com.unity.ml-agents/Runtime/SensorHelper.cs
  10. 28
      com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
  11. 28
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
  12. 190
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  13. 31
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  14. 156
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
  15. 10
      com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
  16. 11
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
  17. 6
      ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
  18. 19
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
  19. 6
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
  20. 4
      ml-agents-envs/mlagents_envs/environment.py
  21. 81
      ml-agents-envs/mlagents_envs/rpc_utils.py
  22. 76
      ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
  23. 3
      protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
  24. 1
      protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
  25. 20
      com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs
  26. 11
      com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta
  27. 23
      com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs
  28. 11
      com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs.meta

12
com.unity.ml-agents/CHANGELOG.md


- Added the Random Network Distillation (RND) intrinsic reward signal to the Pytorch
trainers. To use RND, add a `rnd` section to the `reward_signals` section of your
yaml configuration file. [More information here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Configuration-File.md#rnd-intrinsic-reward)
- Stacking for compressed observations is now supported. An addtional setting
option `Observation Stacks` is added in editor to sensor components that support
compressed observations. A new class `ISparseChannelSensor` with an
additional method `GetCompressedChannelMapping()`is added to generate a mapping
of the channels in compressed data to the actual channel after decompression,
for the python side to decompress correctly. (#4476)
- The Communication API was changed to 1.2.0 to indicate support for stacked
compressed observation. A new entry `compressed_channel_mapping` is added to the
proto to handle decompression correctly. Newer versions of the package that wish to
make use of this will also need a compatible version of the Python trainers. (#4476)
### Bug Fixes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)

1
com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs


EditorGUILayout.PropertyField(so.FindProperty("m_Width"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);

1
com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs


EditorGUILayout.PropertyField(so.FindProperty("m_RenderTexture"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
}
EditorGUI.EndDisabledGroup();

6
com.unity.ml-agents/Runtime/Academy.cs


/// <term>1.1.0</term>
/// <description>Support concatenated PNGs for compressed observations.</description>
/// </item>
/// <item>
/// <term>1.2.0</term>
/// <description>Support compression mapping for stacked compressed observations.</description>
/// </item>
const string k_ApiVersion = "1.1.0";
const string k_ApiVersion = "1.2.0";
/// <summary>
/// Unity package version of com.unity.ml-agents.

62
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


/// <summary>
/// Static flag to make sure that we only fire the warning once.
/// </summary>
private static bool s_HaveWarnedAboutTrainerCapabilities = false;
private static bool s_HaveWarnedTrainerCapabilitiesMultiPng = false;
private static bool s_HaveWarnedTrainerCapabilitiesMapping = false;
/// <summary>
/// Generate an ObservationProto for the sensor using the provided ObservationWriter.

var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
if (!trainerCanHandle)
{
if (!s_HaveWarnedAboutTrainerCapabilities)
if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
s_HaveWarnedAboutTrainerCapabilities = true;
s_HaveWarnedTrainerCapabilitiesMultiPng = true;
}
compressionType = SensorCompressionType.None;
}
}
// Check capabilities if we need mapping for compressed observations
if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
{
var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
var isTrivialMapping = IsTrivialMapping(sensor);
if (!trainerCanHandleMapping && !isTrivialMapping)
{
if (!s_HaveWarnedTrainerCapabilitiesMapping)
{
Debug.LogWarning($"The sensor {sensor.GetName()} is using non-trivial mapping and " +
"the attached trainer doesn't support compression mapping. " +
"Switching to uncompressed observations.");
s_HaveWarnedTrainerCapabilitiesMapping = true;
}
compressionType = SensorCompressionType.None;
}

"return SensorCompressionType.None from GetCompressionType()."
);
}
var compressibleSensor = sensor as ISparseChannelSensor;
if (compressibleSensor != null)
{
observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
}
}
observationProto.Shape.AddRange(shape);
return observationProto;

return new UnityRLCapabilities
{
BaseRLCapabilities = proto.BaseRLCapabilities,
ConcatenatedPngObservations = proto.ConcatenatedPngObservations
ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
CompressedChannelMapping = proto.CompressedChannelMapping,
};
}

{
BaseRLCapabilities = rlCaps.BaseRLCapabilities,
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
CompressedChannelMapping = rlCaps.CompressedChannelMapping,
}
internal static bool IsTrivialMapping(ISensor sensor)
{
var compressibleSensor = sensor as ISparseChannelSensor;
if (compressibleSensor is null)
{
return true;
}
var mapping = compressibleSensor.GetCompressedChannelMapping();
if (mapping == null)
{
return true;
}
// check if mapping equals zero mapping
if (mapping.Length == 3 && mapping.All(m => m == 0))
{
return true;
}
// check if mapping equals identity mapping
for (var i = 0; i < mapping.Length; i++)
{
if (mapping[i] != i)
{
return false;
}
}
return true;
}
}
}

4
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs


{
public bool BaseRLCapabilities;
public bool ConcatenatedPngObservations;
public bool CompressedChannelMapping;
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true)
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true)
CompressedChannelMapping = compressedChannelMapping;
}
/// <summary>

40
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiWwoYVW5pdHlSTENh",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMifQoYVW5pdHlSTENh",
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAhCJaoCIlVuaXR5",
"Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAgSIAoYY29tcHJl",
"c3NlZENoYW5uZWxNYXBwaW5nGAMgASgIQiWqAiJVbml0eS5NTEFnZW50cy5D",
"b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping" }, null, null, null)
}));
}
#endregion

public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
baseRLCapabilities_ = other.baseRLCapabilities_;
concatenatedPngObservations_ = other.concatenatedPngObservations_;
compressedChannelMapping_ = other.compressedChannelMapping_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "compressedChannelMapping" field.</summary>
public const int CompressedChannelMappingFieldNumber = 3;
private bool compressedChannelMapping_;
/// <summary>
/// compression mapping for stacking compressed observations.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool CompressedChannelMapping {
get { return compressedChannelMapping_; }
set {
compressedChannelMapping_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);

}
if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

if (ConcatenatedPngObservations != false) {
output.WriteRawTag(16);
output.WriteBool(ConcatenatedPngObservations);
}
if (CompressedChannelMapping != false) {
output.WriteRawTag(24);
output.WriteBool(CompressedChannelMapping);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);

if (ConcatenatedPngObservations != false) {
size += 1 + 1;
}
if (CompressedChannelMapping != false) {
size += 1 + 1;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

if (other.ConcatenatedPngObservations != false) {
ConcatenatedPngObservations = other.ConcatenatedPngObservations;
}
if (other.CompressedChannelMapping != false) {
CompressedChannelMapping = other.CompressedChannelMapping;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 16: {
ConcatenatedPngObservations = input.ReadBool();
break;
}
case 24: {
CompressedChannelMapping = input.ReadBool();
break;
}
}

34
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyL5AQoQT2JzZXJ2YXRp",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKdAgoQT2JzZXJ2YXRp",
"RmxvYXREYXRhSAAaGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2Jz",
"ZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05F",
"EAASBwoDUE5HEAFCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9i",
"amVjdHNiBnByb3RvMw=="));
"RmxvYXREYXRhSAASIgoaY29tcHJlc3NlZF9jaGFubmVsX21hcHBpbmcYBSAD",
"KAUaGQoJRmxvYXREYXRhEgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25f",
"ZGF0YSopChRDb21wcmVzc2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5H",
"EAFCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy",
"b3RvMw=="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
}));
}
#endregion

public ObservationProto(ObservationProto other) : this() {
shape_ = other.shape_.Clone();
compressionType_ = other.compressionType_;
compressedChannelMapping_ = other.compressedChannelMapping_.Clone();
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;

}
}
/// <summary>Field number for the "compressed_channel_mapping" field.</summary>
public const int CompressedChannelMappingFieldNumber = 5;
private static readonly pb::FieldCodec<int> _repeated_compressedChannelMapping_codec
= pb::FieldCodec.ForInt32(42);
private readonly pbc::RepeatedField<int> compressedChannelMapping_ = new pbc::RepeatedField<int>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<int> CompressedChannelMapping {
get { return compressedChannelMapping_; }
}
private object observationData_;
/// <summary>Enum of possible cases for the "observation_data" oneof.</summary>
public enum ObservationDataOneofCase {

if (CompressionType != other.CompressionType) return false;
if (CompressedData != other.CompressedData) return false;
if (!object.Equals(FloatData, other.FloatData)) return false;
if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false;
if (ObservationDataCase != other.ObservationDataCase) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (CompressionType != 0) hash ^= CompressionType.GetHashCode();
if (observationDataCase_ == ObservationDataOneofCase.CompressedData) hash ^= CompressedData.GetHashCode();
if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode();
hash ^= compressedChannelMapping_.GetHashCode();
hash ^= (int) observationDataCase_;
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();

output.WriteRawTag(34);
output.WriteMessage(FloatData);
}
compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

if (observationDataCase_ == ObservationDataOneofCase.FloatData) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(FloatData);
}
size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

if (other.CompressionType != 0) {
CompressionType = other.CompressionType;
}
compressedChannelMapping_.Add(other.compressedChannelMapping_);
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;

}
input.ReadMessage(subBuilder);
FloatData = subBuilder;
break;
}
case 42:
case 40: {
compressedChannelMapping_.AddEntriesFrom(input, _repeated_compressedChannelMapping_codec);
break;
}
}

54
com.unity.ml-agents/Runtime/SensorHelper.cs


using UnityEngine;
using Unity.Barracuda;
namespace Unity.MLAgents.Sensors
{

}
}
errorMessage = null;
return true;
}
public static bool CompareObservation(ISensor sensor, float[,,] expected, out string errorMessage)
{
var tensorShape = new TensorShape(0, expected.GetLength(0), expected.GetLength(1), expected.GetLength(2));
var numExpected = tensorShape.height * tensorShape.width * tensorShape.channels;
const float fill = -1337f;
var output = new float[numExpected];
for (var i = 0; i < numExpected; i++)
{
output[i] = fill;
}
if (numExpected > 0)
{
if (fill != output[0])
{
errorMessage = "Error setting output buffer.";
return false;
}
}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)
{
if (fill != output[0])
{
errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have.";
return false;
}
}
sensor.Write(writer);
for (var h = 0; h < tensorShape.height; h++)
{
for (var w = 0; w < tensorShape.width; w++)
{
for (var c = 0; c < tensorShape.channels; c++)
{
if (expected[h, w, c] != output[tensorShape.Index(0, h, w, c)])
{
errorMessage = $"Expected and actual differed in position [{h}, {w}, {c}]. " +
"Expected: {expected[h, w, c]} Actual: {output[tensorShape.Index(0, h, w, c)]} ";
return false;
}
}
}
}
errorMessage = null;
return true;
}

28
com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs


set { m_Grayscale = value; }
}
[HideInInspector, SerializeField]
[Range(1, 50)]
[Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")]
int m_ObservationStacks = 1;
[HideInInspector, SerializeField, FormerlySerializedAs("compression")]
SensorCompressionType m_Compression = SensorCompressionType.PNG;

}
/// <summary>
/// Whether to stack previous observations. Using 1 means no previous observations.
/// Note that changing this after the sensor is created has no effect.
/// </summary>
public int ObservationStacks
{
get { return m_ObservationStacks; }
set { m_ObservationStacks = value; }
}
/// <summary>
/// Creates the <see cref="CameraSensor"/>
/// </summary>
/// <returns>The created <see cref="CameraSensor"/> object for this component.</returns>

if (ObservationStacks != 1)
{
return new StackingSensor(m_Sensor, ObservationStacks);
}
return m_Sensor;
}

/// <returns>The observation shape of the associated <see cref="CameraSensor"/> object.</returns>
public override int[] GetObservationShape()
{
return CameraSensor.GenerateShape(m_Width, m_Height, Grayscale);
var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
var cameraSensorshape = CameraSensor.GenerateShape(m_Width, m_Height, Grayscale);
if (stacks > 1)
{
cameraSensorshape[cameraSensorshape.Length - 1] *= stacks;
}
return cameraSensorshape;
}
/// <summary>

28
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs


set { m_Grayscale = value; }
}
[HideInInspector, SerializeField]
[Range(1, 50)]
[Tooltip("Number of frames that will be stacked before being fed to the neural network.")]
int m_ObservationStacks = 1;
[HideInInspector, SerializeField, FormerlySerializedAs("compression")]
SensorCompressionType m_Compression = SensorCompressionType.PNG;

set { m_Compression = value; UpdateSensor(); }
}
/// <summary>
/// Whether to stack previous observations. Using 1 means no previous observations.
/// Note that changing this after the sensor is created has no effect.
/// </summary>
public int ObservationStacks
{
get { return m_ObservationStacks; }
set { m_ObservationStacks = value; }
}
if (ObservationStacks != 1)
{
return new StackingSensor(m_Sensor, ObservationStacks);
}
return m_Sensor;
}

var width = RenderTexture != null ? RenderTexture.width : 0;
var height = RenderTexture != null ? RenderTexture.height : 0;
var observationShape = new[] { height, width, Grayscale ? 1 : 3 };
return new[] { height, width, Grayscale ? 1 : 3 };
var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
if (stacks > 1)
{
observationShape[2] *= stacks;
}
return observationShape;
}
/// <summary>

190
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


using System;
using System.Linq;
using System.Runtime.CompilerServices;
using UnityEngine;
using Unity.Barracuda;
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
namespace Unity.MLAgents.Sensors
{

/// For example, 4 stacked sets of observations would be output like
/// | t = now - 3 | t = now -3 | t = now - 2 | t = now |
/// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation.
///
/// Currently, compressed and multidimensional observations are not supported.
/// Currently, observations are stacked on the last dimension.
public class StackingSensor : ISensor
public class StackingSensor : ISparseChannelSensor
{
/// <summary>
/// The wrapped sensor.

string m_Name;
int[] m_Shape;
int[] m_WrappedShape;
byte[][] m_StackedCompressedObservations;
byte[] m_EmptyCompressedObservation;
int[] m_CompressionMapping;
TensorShape m_tensorShape;
/// <summary>
/// Initializes the sensor.
/// </summary>

m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";
if (wrapped.GetCompressionType() != SensorCompressionType.None)
m_WrappedShape = wrapped.GetObservationShape();
m_Shape = new int[m_WrappedShape.Length];
m_UnstackedObservationSize = wrapped.ObservationSize();
for (int d = 0; d < m_WrappedShape.Length; d++)
throw new UnityAgentsException("StackingSensor doesn't support compressed observations.'");
m_Shape[d] = m_WrappedShape[d];
var shape = wrapped.GetObservationShape();
if (shape.Length != 1)
// TODO support arbitrary stacking dimension
m_Shape[m_Shape.Length - 1] *= numStackedObservations;
// Initialize uncompressed buffer anyway in case python trainer does not
// support the compression mapping and has to fall back to uncompressed obs.
m_StackedObservations = new float[numStackedObservations][];
for (var i = 0; i < numStackedObservations; i++)
throw new UnityAgentsException("Only 1-D observations are supported by StackingSensor");
m_StackedObservations[i] = new float[m_UnstackedObservationSize];
m_Shape = new int[shape.Length];
m_UnstackedObservationSize = wrapped.ObservationSize();
for (int d = 0; d < shape.Length; d++)
if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None)
m_Shape[d] = shape[d];
m_StackedCompressedObservations = new byte[numStackedObservations][];
m_EmptyCompressedObservation = CreateEmptyPNG();
for (var i = 0; i < numStackedObservations; i++)
{
m_StackedCompressedObservations[i] = m_EmptyCompressedObservation;
}
m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
// TODO support arbitrary stacking dimension
m_Shape[0] *= numStackedObservations;
m_StackedObservations = new float[numStackedObservations][];
for (var i = 0; i < numStackedObservations; i++)
if (m_Shape.Length != 1)
m_StackedObservations[i] = new float[m_UnstackedObservationSize];
m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]);
}
}

// First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one.
var wrappedShape = m_WrappedSensor.GetObservationShape();
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], wrappedShape, 0);
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedShape, 0);
for (var i = 0; i < m_NumStackedObservations; i++)
if (m_WrappedShape.Length == 1)
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
writer.AddRange(m_StackedObservations[obsIndex], numWritten);
numWritten += m_UnstackedObservationSize;
for (var i = 0; i < m_NumStackedObservations; i++)
{
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
writer.AddRange(m_StackedObservations[obsIndex], numWritten);
numWritten += m_UnstackedObservationSize;
}
}
else
{
for (var i = 0; i < m_NumStackedObservations; i++)
{
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
for (var h = 0; h < m_WrappedShape[0]; h++)
{
for (var w = 0; w < m_WrappedShape[1]; w++)
{
for (var c = 0; c < m_WrappedShape[2]; c++)
{
writer[h, w, i * m_WrappedShape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
}
}
}
}
numWritten = m_WrappedShape[0] * m_WrappedShape[1] * m_WrappedShape[2] * m_NumStackedObservations;
}
return numWritten;

{
Array.Clear(m_StackedObservations[i], 0, m_StackedObservations[i].Length);
}
if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None)
{
for (var i = 0; i < m_NumStackedObservations; i++)
{
m_StackedCompressedObservations[i] = m_EmptyCompressedObservation;
}
}
}
/// <inheritdoc/>

}
/// <inheritdoc/>
public virtual byte[] GetCompressedObservation()
public byte[] GetCompressedObservation()
return null;
var compressed = m_WrappedSensor.GetCompressedObservation();
m_StackedCompressedObservations[m_CurrentIndex] = compressed;
int bytesLength = 0;
foreach (byte[] compressedObs in m_StackedCompressedObservations)
{
bytesLength += compressedObs.Length;
}
byte[] outputBytes = new byte[bytesLength];
int offset = 0;
for (var i = 0; i < m_NumStackedObservations; i++)
{
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
Buffer.BlockCopy(m_StackedCompressedObservations[obsIndex],
0, outputBytes, offset, m_StackedCompressedObservations[obsIndex].Length);
offset += m_StackedCompressedObservations[obsIndex].Length;
}
return outputBytes;
}
public int[] GetCompressedChannelMapping()
{
return m_CompressionMapping;
public virtual SensorCompressionType GetCompressionType()
public SensorCompressionType GetCompressionType()
return SensorCompressionType.None;
return m_WrappedSensor.GetCompressionType();
// TODO support stacked compressed observations (byte stream)
/// <summary>
/// Create Empty PNG for initializing the buffer for stacking.
/// </summary>
internal byte[] CreateEmptyPNG()
{
int height = m_WrappedSensor.GetObservationShape()[0];
int width = m_WrappedSensor.GetObservationShape()[1];
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
return texture2D.EncodeToPNG();
}
/// <summary>
/// Constrct stacked CompressedChannelMapping.
/// </summary>
internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSenesor)
{
// Get CompressedChannelMapping of the wrapped sensor. If the
// wrapped sensor doesn't have one, use default mapping.
// Default mapping: {0, 0, 0} for grayscale, identity mapping {1, 2, ..., n} otherwise.
int[] wrappedMapping = null;
int wrappedNumChannel = wrappedSenesor.GetObservationShape()[2];
var sparseChannelSensor = m_WrappedSensor as ISparseChannelSensor;
if (sparseChannelSensor != null)
{
wrappedMapping = sparseChannelSensor.GetCompressedChannelMapping();
}
if (wrappedMapping == null)
{
if (wrappedNumChannel == 1)
{
wrappedMapping = new int[] { 0, 0, 0 };
}
else
{
wrappedMapping = Enumerable.Range(0, wrappedNumChannel).ToArray();
}
}
// Construct stacked mapping using the mapping of wrapped sensor.
// First pad the wrapped mapping to multiple of 3, then repeat
// and add offset to each copy to form the stacked mapping.
int paddedMapLength = (wrappedMapping.Length + 2) / 3 * 3;
var compressionMapping = new int[paddedMapLength * m_NumStackedObservations];
for (var i = 0; i < m_NumStackedObservations; i++)
{
var offset = wrappedNumChannel * i;
for (var j = 0; j < paddedMapLength; j++)
{
if (j < wrappedMapping.Length)
{
compressionMapping[j + paddedMapLength * i] = wrappedMapping[j] >= 0 ? wrappedMapping[j] + offset : -1;
}
else
{
compressionMapping[j + paddedMapLength * i] = -1;
}
}
}
return compressionMapping;
}
}
}

31
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


}
}
class DummySparseChannelSensor : DummySensor, ISparseChannelSensor
{
public int[] Mapping;
internal DummySparseChannelSensor()
{
}
public int[] GetCompressedChannelMapping()
{
return Mapping;
}
}
[Test]
public void TestGetObservationProtoCapabilities()
{

}
}
[Test]
public void TestIsTrivialMapping()
{
Assert.AreEqual(GrpcExtensions.IsTrivialMapping(new DummySensor()), true);
var sparseChannelSensor = new DummySparseChannelSensor();
sparseChannelSensor.Mapping = null;
Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true);
sparseChannelSensor.Mapping = new int[] { 0, 0, 0 };
Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true);
sparseChannelSensor.Mapping = new int[] { 0, 1, 2, 3, 4 };
Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), true);
sparseChannelSensor.Mapping = new int[] { 1, 2, 3, 4, -1, -1 };
Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false);
sparseChannelSensor.Mapping = new int[] { 0, 0, 0, 1, 1, 1 };
Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false);
}
}
}

156
com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs


using NUnit.Framework;
using System;
using System.Linq;
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests

}
[Test]
public void TestStacking()
public void TestVectorStacking()
{
VectorSensor wrapped = new VectorSensor(2);
ISensor sensor = new StackingSensor(wrapped, 3);

}
[Test]
public void TestStackingReset()
public void TestVectorStackingReset()
{
VectorSensor wrapped = new VectorSensor(2);
ISensor sensor = new StackingSensor(wrapped, 3);

sensor.Reset();
wrapped.AddObservation(new[] { 5f, 6f });
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 5f, 6f });
}
class Dummy3DSensor : ISparseChannelSensor
{
public SensorCompressionType CompressionType = SensorCompressionType.PNG;
public int[] Mapping;
public int[] Shape;
public float[,,] CurrentObservation;
internal Dummy3DSensor()
{
}
public int[] GetObservationShape()
{
return Shape;
}
public int Write(ObservationWriter writer)
{
for (var h = 0; h < Shape[0]; h++)
{
for (var w = 0; w < Shape[1]; w++)
{
for (var c = 0; c < Shape[2]; c++)
{
writer[h, w, c] = CurrentObservation[h, w, c];
}
}
}
return Shape[0] * Shape[1] * Shape[2];
}
public byte[] GetCompressedObservation()
{
var writer = new ObservationWriter();
var flattenedObservation = new float[Shape[0] * Shape[1] * Shape[2]];
writer.SetTarget(flattenedObservation, Shape, 0);
Write(writer);
byte[] bytes = Array.ConvertAll(flattenedObservation, (z) => (byte)z);
return bytes;
}
public void Update() { }
public void Reset() { }
public SensorCompressionType GetCompressionType()
{
return CompressionType;
}
public string GetName()
{
return "Dummy";
}
public int[] GetCompressedChannelMapping()
{
return Mapping;
}
}
[Test]
public void TestStackingMapping()
{
// Test grayscale stacked mapping with CameraSensor
var cameraSensor = new CameraSensor(new Camera(), 64, 64,
true, "grayscaleCamera", SensorCompressionType.PNG);
var stackedCameraSensor = new StackingSensor(cameraSensor, 2);
Assert.AreEqual(stackedCameraSensor.GetCompressedChannelMapping(), new[] { 0, 0, 0, 1, 1, 1 });
// Test RGB stacked mapping with RenderTextureSensor
var renderTextureSensor = new RenderTextureSensor(new RenderTexture(24, 16, 0),
false, "renderTexture", SensorCompressionType.PNG);
var stackedRenderTextureSensor = new StackingSensor(renderTextureSensor, 2);
Assert.AreEqual(stackedRenderTextureSensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, 4, 5 });
// Test mapping with number of layers not being multiple of 3
var dummySensor = new Dummy3DSensor();
dummySensor.Shape = new int[] { 2, 2, 4 };
dummySensor.Mapping = new int[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
// Test mapping with dummy layers that should be dropped
var paddedDummySensor = new Dummy3DSensor();
paddedDummySensor.Shape = new int[] { 2, 2, 4 };
paddedDummySensor.Mapping = new int[] { 0, 1, 2, 3, -1, -1 };
var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2);
Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
}
[Test]
public void Test3DStacking()
{
var wrapped = new Dummy3DSensor();
wrapped.Shape = new int[] { 2, 1, 2 };
var sensor = new StackingSensor(wrapped, 2);
// Check the stacking is on the last dimension
wrapped.CurrentObservation = new[, ,] { { { 1f, 2f } }, { { 3f, 4f } } };
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 1f, 2f } }, { { 0f, 0f, 3f, 4f } } });
sensor.Update();
wrapped.CurrentObservation = new[, ,] { { { 5f, 6f } }, { { 7f, 8f } } };
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 1f, 2f, 5f, 6f } }, { { 3f, 4f, 7f, 8f } } });
sensor.Update();
wrapped.CurrentObservation = new[, ,] { { { 9f, 10f } }, { { 11f, 12f } } };
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } });
// Check that if we don't call Update(), the same observations are produced
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } });
// Test reset
sensor.Reset();
wrapped.CurrentObservation = new[, ,] { { { 13f, 14f } }, { { 15f, 16f } } };
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 13f, 14f } }, { { 0f, 0f, 15f, 16f } } });
}
[Test]
public void TestStackedGetCompressedObservation()
{
var wrapped = new Dummy3DSensor();
wrapped.Shape = new int[] { 1, 1, 3 };
var sensor = new StackingSensor(wrapped, 2);
wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } };
var expected1 = sensor.CreateEmptyPNG();
expected1 = expected1.Concat(Array.ConvertAll(new[] { 1f, 2f, 3f }, (z) => (byte)z)).ToArray();
Assert.AreEqual(sensor.GetCompressedObservation(), expected1);
sensor.Update();
wrapped.CurrentObservation = new[, ,] { { { 4f, 5f, 6f } } };
var expected2 = Array.ConvertAll(new[] { 1f, 2f, 3f, 4f, 5f, 6f }, (z) => (byte)z);
Assert.AreEqual(sensor.GetCompressedObservation(), expected2);
sensor.Update();
wrapped.CurrentObservation = new[, ,] { { { 7f, 8f, 9f } } };
var expected3 = Array.ConvertAll(new[] { 4f, 5f, 6f, 7f, 8f, 9f }, (z) => (byte)z);
Assert.AreEqual(sensor.GetCompressedObservation(), expected3);
// Test reset
sensor.Reset();
wrapped.CurrentObservation = new[, ,] { { { 10f, 11f, 12f } } };
var expected4 = sensor.CreateEmptyPNG();
expected4 = expected4.Concat(Array.ConvertAll(new[] { 10f, 11f, 12f }, (z) => (byte)z)).ToArray();
Assert.AreEqual(sensor.GetCompressedObservation(), expected4);
}
}
}

10
com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs


namespace Unity.MLAgents.Tests
{
public static class SensorTestHelper
{
public static void CompareObservation(ISensor sensor, float[] expected)
{
string errorMessage;
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
Assert.IsTrue(isOK, errorMessage);
}
}
public class VectorSensorTests
{
[Test]

11
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py


name='mlagents_envs/communicator_objects/capabilities.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"[\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"}\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='compressedChannelMapping', full_name='communicator_objects.UnityRLCapabilitiesProto.compressedChannelMapping', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=79,
serialized_end=170,
serialized_end=204,
)
DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO

6
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi


DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
baseRLCapabilities = ... # type: builtin___bool
concatenatedPngObservations = ... # type: builtin___bool
compressedChannelMapping = ... # type: builtin___bool
compressedChannelMapping : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"concatenatedPngObservations"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"concatenatedPngObservations",b"concatenatedPngObservations"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations"]) -> None: ...

19
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py


name='mlagents_envs/communicator_objects/observation.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\xf9\x01\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x9d\x02\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(

],
containing_type=None,
options=None,
serialized_start=330,
serialized_end=371,
serialized_start=366,
serialized_end=407,
)
_sym_db.RegisterEnumDescriptor(_COMPRESSIONTYPEPROTO)

extension_ranges=[],
oneofs=[
],
serialized_start=283,
serialized_end=308,
serialized_start=319,
serialized_end=344,
)
_OBSERVATIONPROTO = _descriptor.Descriptor(

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='compressed_channel_mapping', full_name='communicator_objects.ObservationProto.compressed_channel_mapping', index=4,
number=5, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

index=0, containing_type=None, fields=[]),
],
serialized_start=79,
serialized_end=328,
serialized_end=364,
)
_OBSERVATIONPROTO_FLOATDATA.containing_type = _OBSERVATIONPROTO

6
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi


shape = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
compression_type = ... # type: CompressionTypeProto
compressed_data = ... # type: builtin___bytes
compressed_channel_mapping = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
@property
def float_data(self) -> ObservationProto.FloatData: ...

compression_type : typing___Optional[CompressionTypeProto] = None,
compressed_data : typing___Optional[builtin___bytes] = None,
float_data : typing___Optional[ObservationProto.FloatData] = None,
compressed_channel_mapping : typing___Optional[typing___Iterable[builtin___int]] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> ObservationProto: ...

def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",u"float_data",u"observation_data"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_data",u"compression_type",u"float_data",u"observation_data",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"float_data",u"observation_data",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions___Literal[u"observation_data",b"observation_data"]) -> typing_extensions___Literal["compressed_data","float_data"]: ...

4
ml-agents-envs/mlagents_envs/environment.py


# Revision history:
# * 1.0.0 - initial version
# * 1.1.0 - support concatenated PNGs for compressed observations.
API_VERSION = "1.1.0"
# * 1.2.0 - support compression mapping for stacked compressed observations.
API_VERSION = "1.2.0"
# Default port that the editor listens on. If an environment executable
# isn't specified, this port will be used.

capabilities = UnityRLCapabilitiesProto()
capabilities.baseRLCapabilities = True
capabilities.concatenatedPngObservations = True
capabilities.compressedChannelMapping = True
return capabilities
@staticmethod

81
ml-agents-envs/mlagents_envs/rpc_utils.py


@timed
def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray:
def process_pixels(
image_bytes: bytes, expected_channels: int, mappings: Optional[List[int]] = None
) -> np.ndarray:
"""
Converts byte array observation image into numpy array, re-sizes it,
and optionally converts it to grey scale

"""
image_fp = OffsetBytesIO(image_bytes)
if expected_channels == 1:
# Convert to grayscale
with hierarchical_timer("image_decompress"):
image = Image.open(image_fp)
# Normally Image loads lazily, load() forces it to do loading in the timer scope.
image.load()
s = np.array(image, dtype=np.float32) / 255.0
s = np.mean(s, axis=2)
s = np.reshape(s, [s.shape[0], s.shape[1], 1])
return s
# Normally Image loads lazily, load() forces it to do loading in the timer scope.
image.load()
image_arrays.append(np.array(image, dtype=np.float32) / 255.0)

# Didn't find the header, so must be at the end.
break
img = np.concatenate(image_arrays, axis=2)
# We can drop additional channels since they may need to be added to include
# numbers of observation channels not divisible by 3.
actual_channels = list(img.shape)[2]
if actual_channels > expected_channels:
img = img[..., 0:expected_channels]
if mappings is not None and len(mappings) > 0:
return _process_images_mapping(image_arrays, mappings)
else:
return _process_images_num_channels(image_arrays, expected_channels)
def _process_images_mapping(image_arrays, mappings):
"""
Helper function for processing decompressed images with compressed channel mappings.
"""
image_arrays = np.concatenate(image_arrays, axis=2).transpose((2, 0, 1))
if len(mappings) != len(image_arrays):
raise UnityObservationException(
f"Compressed observation and its mapping had different number of channels - "
f"observation had {len(image_arrays)} channels but its mapping had {len(mappings)} channels"
)
if len({m for m in mappings if m > -1}) != max(mappings) + 1:
raise UnityObservationException(
f"Invalid Compressed Channel Mapping: the mapping {mappings} does not have the correct format."
)
if max(mappings) >= len(image_arrays):
raise UnityObservationException(
f"Invalid Compressed Channel Mapping: the mapping has index larger than the total "
f"number of channels in observation - mapping index {max(mappings)} is"
f"invalid for input observation with {len(image_arrays)} channels."
)
processed_image_arrays: List[np.array] = [[] for _ in range(max(mappings) + 1)]
for mapping_idx, img in zip(mappings, image_arrays):
if mapping_idx > -1:
processed_image_arrays[mapping_idx].append(img)
for i, img_array in enumerate(processed_image_arrays):
processed_image_arrays[i] = np.mean(img_array, axis=0)
img = np.stack(processed_image_arrays, axis=2)
return img
def _process_images_num_channels(image_arrays, expected_channels):
"""
Helper function for processing decompressed images with number of expected channels.
This is for old API without mapping provided. Use the first n channel, n=expected_channels.
"""
if expected_channels == 1:
# Convert to grayscale
img = np.mean(image_arrays[0], axis=2)
img = np.reshape(img, [img.shape[0], img.shape[1], 1])
else:
img = np.concatenate(image_arrays, axis=2)
# We can drop additional channels since they may need to be added to include
# numbers of observation channels not divisible by 3.
actual_channels = list(img.shape)[2]
if actual_channels > expected_channels:
img = img[..., 0:expected_channels]
return img

img = np.reshape(img, obs.shape)
return img
else:
img = process_pixels(obs.compressed_data, expected_channels)
img = process_pixels(
obs.compressed_data, expected_channels, list(obs.compressed_channel_mapping)
)
# Compare decompressed image size to observation shape and make sure they match
if list(obs.shape) != list(img.shape):
raise UnityObservationException(

76
ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py


return bytes_out
def generate_compressed_proto_obs(in_array: np.ndarray) -> ObservationProto:
# test helper function for old C# API (no compressed channel mapping)
def generate_compressed_proto_obs(
in_array: np.ndarray, grayscale: bool = False
) -> ObservationProto:
obs_proto.shape.extend(in_array.shape)
if grayscale:
# grayscale flag is only used for old API without mapping
expected_shape = [in_array.shape[0], in_array.shape[1], 1]
obs_proto.shape.extend(expected_shape)
else:
obs_proto.shape.extend(in_array.shape)
return obs_proto
# test helper function for new C# API (with compressed channel mapping)
def generate_compressed_proto_obs_with_mapping(
in_array: np.ndarray, mapping: List[int]
) -> ObservationProto:
obs_proto = ObservationProto()
obs_proto.compressed_data = generate_compressed_data(in_array)
obs_proto.compression_type = PNG
if mapping is not None:
obs_proto.compressed_channel_mapping.extend(mapping)
expected_shape = [
in_array.shape[0],
in_array.shape[1],
len({m for m in mapping if m >= 0}),
]
obs_proto.shape.extend(expected_shape)
else:
obs_proto.shape.extend(in_array.shape)
return obs_proto

in_array_1 = np.random.rand(128, 64, 3)
proto_obs_1 = generate_compressed_proto_obs(in_array_1)
in_array_2 = np.random.rand(128, 64, 3)
proto_obs_2 = generate_uncompressed_proto_obs(in_array_2)
in_array_2_mapping = [0, 1, 2]
proto_obs_2 = generate_compressed_proto_obs_with_mapping(
in_array_2, in_array_2_mapping
)
ap1 = AgentInfoProto()
ap1.observations.extend([proto_obs_1])
ap2 = AgentInfoProto()

assert list(arr.shape) == [2, 128, 64, 3]
assert np.allclose(arr[0, :, :, :], in_array_1, atol=0.01)
assert np.allclose(arr[1, :, :, :], in_array_2, atol=0.01)
def test_process_visual_observation_grayscale():
in_array_1 = np.random.rand(128, 64, 3)
proto_obs_1 = generate_compressed_proto_obs(in_array_1, grayscale=True)
expected_out_array_1 = np.mean(in_array_1, axis=2, keepdims=True)
in_array_2 = np.random.rand(128, 64, 3)
in_array_2_mapping = [0, 0, 0]
proto_obs_2 = generate_compressed_proto_obs_with_mapping(
in_array_2, in_array_2_mapping
)
expected_out_array_2 = np.mean(in_array_2, axis=2, keepdims=True)
ap1 = AgentInfoProto()
ap1.observations.extend([proto_obs_1])
ap2 = AgentInfoProto()
ap2.observations.extend([proto_obs_2])
ap_list = [ap1, ap2]
arr = _process_visual_observation(0, (128, 64, 1), ap_list)
assert list(arr.shape) == [2, 128, 64, 1]
assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01)
assert np.allclose(arr[1, :, :, :], expected_out_array_2, atol=0.01)
def test_process_visual_observation_padded_channels():
in_array_1 = np.random.rand(128, 64, 12)
in_array_1_mapping = [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1]
proto_obs_1 = generate_compressed_proto_obs_with_mapping(
in_array_1, in_array_1_mapping
)
expected_out_array_1 = np.take(in_array_1, [0, 1, 2, 3, 6, 7, 8, 9], axis=2)
ap1 = AgentInfoProto()
ap1.observations.extend([proto_obs_1])
ap_list = [ap1]
arr = _process_visual_observation(0, (128, 64, 8), ap_list)
assert list(arr.shape) == [1, 128, 64, 8]
assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01)
def test_process_visual_observation_bad_shape():

3
protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto


// concatenated PNG files for compressed visual observations with >3 channels.
bool concatenatedPngObservations = 2;
// compression mapping for stacking compressed observations.
bool compressedChannelMapping = 3;
}

1
protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto


bytes compressed_data = 3;
FloatData float_data = 4;
}
repeated int32 compressed_channel_mapping = 5;
}

20
com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs


namespace Unity.MLAgents.Sensors
{
/// <summary>
/// Sensor interface for sparse channel sensor which requires a compressed channel mapping.
/// </summary>
public interface ISparseChannelSensor : ISensor
{
/// <summary>
/// Returns the mapping of the channels in compressed data to the actual channel after decompression.
/// The mapping is a list of interger index with the same length as
/// the number of output observation layers (channels), including padding if there's any.
/// Each index indicates the actual channel the layer will go into.
/// Layers with the same index will be averaged, and layers with negative index will be dropped.
/// For example, mapping for CameraSensor using grayscale and stacking of two: [0, 0, 0, 1, 1, 1]
/// Mapping for GridSensor of 4 channels and stacking of two: [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1]
/// </summary>
/// <returns>Mapping of the compressed data</returns>
int[] GetCompressedChannelMapping();
}
}

11
com.unity.ml-agents/Runtime/Sensors/ISparseChannelSensor.cs.meta


fileFormatVersion: 2
guid: 63bb76c1e31c24fa5b4a384ea0edbfb0
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

23
com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs


using NUnit.Framework;
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{
public static class SensorTestHelper
{
public static void CompareObservation(ISensor sensor, float[] expected)
{
string errorMessage;
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
Assert.IsTrue(isOK, errorMessage);
}
public static void CompareObservation(ISensor sensor, float[,,] expected)
{
string errorMessage;
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
Assert.IsTrue(isOK, errorMessage);
}
}
}

11
com.unity.ml-agents/Tests/Editor/Sensor/SensorTestHelper.cs.meta


fileFormatVersion: 2
guid: e769354f8bd404ca180d7cd7302a5d61
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存