浏览代码

[MLA-1634] Add ObservationSpec and update ISensor interfaces (#5127)

/goal-conditioning/sensors-3-pytest-fix
Christopher Goy 4 年前
当前提交
113c1bca
共有 55 个文件被更改,包括 1009 次插入374 次删除
  1. 2
      DevProject/Packages/manifest.json
  2. 2
      DevProject/Packages/packages-lock.json
  3. 18
      DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json
  4. 4
      DevProject/ProjectSettings/ProjectVersion.txt
  5. 4
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
  6. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
  7. 5
      Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta
  8. 8
      Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
  9. 12
      com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
  10. 30
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  11. 10
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  12. 12
      com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
  13. 8
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs
  14. 6
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs
  15. 10
      com.unity.ml-agents/CHANGELOG.md
  16. 1
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
  17. 7
      com.unity.ml-agents/Runtime/Analytics/Events.cs
  18. 55
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  19. 26
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  20. 9
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  21. 2
      com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
  22. 4
      com.unity.ml-agents/Runtime/SensorHelper.cs
  23. 24
      com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
  24. 32
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  25. 70
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  26. 13
      com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
  27. 8
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
  28. 12
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  29. 8
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
  30. 21
      com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
  31. 48
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  32. 10
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  33. 23
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  34. 19
      com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
  35. 4
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  36. 6
      com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs
  37. 3
      com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs
  38. 13
      com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
  39. 2
      com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs
  40. 20
      com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs
  41. 2
      com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs
  42. 27
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  43. 30
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
  44. 2
      com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
  45. 241
      com.unity.ml-agents/Runtime/InplaceArray.cs
  46. 3
      com.unity.ml-agents/Runtime/InplaceArray.cs.meta
  47. 140
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
  48. 3
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta
  49. 192
      com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs
  50. 78
      com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs
  51. 3
      com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta
  52. 31
      com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs
  53. 11
      com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta
  54. 47
      com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs
  55. 0
      /com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta

2
DevProject/Packages/manifest.json


"com.unity.package-manager-doctools": "1.7.0-preview",
"com.unity.package-validation-suite": "0.19.0-preview",
"com.unity.purchasing": "2.2.1",
"com.unity.test-framework": "1.1.20",
"com.unity.test-framework": "1.1.22",
"com.unity.test-framework.performance": "2.2.0-preview",
"com.unity.testtools.codecoverage": "1.0.0-pre.3",
"com.unity.textmeshpro": "2.0.1",

2
DevProject/Packages/packages-lock.json


"url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates"
},
"com.unity.test-framework": {
"version": "1.1.20",
"version": "1.1.22",
"depth": 0,
"source": "registry",
"dependencies": {

18
DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json


"m_Name": "Settings",
"m_Path": "ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json",
"m_Dictionary": {
"m_DictionaryValues": []
"m_DictionaryValues": [
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "Path",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "HistoryPath",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "IncludeAssemblies",
"value": "{\"m_Value\":\"Assembly-CSharp,Runtime,Unity.ML-Agents,Unity.ML-Agents.Extensions\"}"
}
]
}
}

4
DevProject/ProjectSettings/ProjectVersion.txt


m_EditorVersion: 2019.4.19f1
m_EditorVersionWithRevision: 2019.4.19f1 (ca5b14067cec)
m_EditorVersion: 2019.4.20f1
m_EditorVersionWithRevision: 2019.4.20f1 (6dd1c08eedfa)

4
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs


}
/// <inheritdoc/>
public override int[] GetObservationShape()
public override ObservationSpec GetObservationSpec()
return new[] { BasicController.k_Extents };
return ObservationSpec.Vector(BasicController.k_Extents);
}
/// <inheritdoc/>

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs


public abstract void WriteObservation(float[] output);
/// <inheritdoc/>
public abstract int[] GetObservationShape();
public abstract ObservationSpec GetObservationSpec();
/// <inheritdoc/>
public abstract string GetName();

5
Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta


fileFormatVersion: 2
guid: 8cd4584c2f2cb4c5fb51675d364e10ec
ScriptedImporter:
internalIDToNameTable: []
fileIDToRecycleName:
11400000: main obj
11400002: model data
serializedVersion: 2
userData:
assetBundleName:
assetBundleVariant:

8
Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs


{
Texture2D m_Texture;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
/// <summary>

var width = texture.width;
var height = texture.height;
m_Name = name;
m_Shape = new[] { height, width, 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, 3);
m_CompressionType = compressionType;
}

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs


{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
private int[] m_SparseChannelMapping;
private string m_Name;

m_NumSpecialTypes = board.NumSpecialTypes;
m_ObservationType = obsType;
m_Shape = obsType == Match3ObservationType.Vector ?
new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
m_ObservationSpec = obsType == Match3ObservationType.Vector
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

30
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


protected bool Initialized = false;
/// <summary>
/// Array holding the dimensions of the resulting tensor
/// Cached ObservationSpec
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
//
// Debug Parameters

// Default root reference to current game object
if (rootReference == null)
rootReference = gameObject;
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
compressedImgs = new List<byte[]>();
byteSizesBytesList = new List<byte[]>();

CellActivity[i] = DebugDefaultColor;
}
}
}
/// <summary>Gets the shape of the grid observation</summary>
/// <returns>integer array shape of the grid observation</returns>
public int[] GetFloatObservationShape()
{
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
}
/// <inheritdoc/>

/// <summary>Gets the observation shape</summary>
/// <returns>int[] of the observation shape</returns>
public ObservationSpec GetObservationSpec()
{
// Lazy update
var shape = m_ObservationSpec.Shape;
if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell)
{
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
}
return m_ObservationSpec;
}
/// <inheritdoc/>
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
var shape = m_ObservationSpec.Shape;
return new int[] { shape[0], shape[1], shape[2] };
}
/// <inheritdoc/>

10
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


/// </summary>
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_SensorName;
PoseExtractor m_PoseExtractor;

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}
#if UNITY_2020_1_OR_NEWER

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs


var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

8
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

6
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 7 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
}
}

10
com.unity.ml-agents/CHANGELOG.md


## [Unreleased]
### Major Changes
#### com.unity.ml-agents (C#)
======
- Several breaking interface changes were made. See the
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details.
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. See the
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. (#5060)
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details. (#5060)
- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. (#5127)
#### ml-agents / ml-agents-envs / gym-unity (Python)
### Minor Changes

1
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs


/// <summary>
/// Set whether or not the action index for the given branch is allowed.
/// </summary>
/// <remarks>
/// By default, all discrete actions are allowed.
/// If isEnabled is false, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndex correspond

7
com.unity.ml-agents/Runtime/Analytics/Events.cs


public static EventObservationSpec FromSensor(ISensor sensor)
{
var shape = sensor.GetObservationShape();
var dimProps = (sensor as IDimensionPropertiesSensor)?.GetDimensionProperties();
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
var dimProps = obsSpec.DimensionProperties;
dimInfos[i].Flags = dimProps != null ? (int)dimProps[i] : 0;
dimInfos[i].Flags = (int)dimProps[i];
}
var builtInSensorType =

55
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


/// <returns></returns>
public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
{
var shape = sensor.GetObservationShape();
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
ObservationProto observationProto = null;
var compressionType = sensor.GetCompressionType();
// Check capabilities if we need to concatenate PNGs

floatDataProto.Data.Add(0.0f);
}
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0);
sensor.Write(observationWriter);
observationProto = new ObservationProto

observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
}
}
// Add the dimension properties if any to the observationProto
var dimensionPropertySensor = sensor as IDimensionPropertiesSensor;
if (dimensionPropertySensor != null)
// Add the dimension properties to the observationProto
var dimensionProperties = obsSpec.DimensionProperties;
for (int i = 0; i < dimensionProperties.Length; i++)
{
observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
}
// Checking trainer compatibility with variable length observations
if (dimensionProperties == new InplaceArray<DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None))
var dimensionProperties = dimensionPropertySensor.GetDimensionProperties();
for (int i = 0; i < dimensionProperties.Length; i++)
var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
if (!trainerCanHandleVarLenObs)
observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
// Checking trainer compatibility with variable length observations
if (dimensionProperties.Length == 2)
{
if (dimensionProperties[0] == DimensionProperty.VariableSize &&
dimensionProperties[1] == DimensionProperty.None)
{
var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
if (!trainerCanHandleVarLenObs)
{
throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
}
}
}
}
for (var i = 0; i < shape.Length; i++)
{
observationProto.Shape.Add(shape[i]);
observationProto.Shape.AddRange(shape);
var sensorName = sensor.GetName();
if (!string.IsNullOrEmpty(sensorName))
{

// Add the observation type, if any, to the observationProto
var typeSensor = sensor as ITypedSensor;
if (typeSensor != null)
{
observationProto.ObservationType = (ObservationTypeProto)typeSensor.GetObservationType();
}
else
{
observationProto.ObservationType = ObservationTypeProto.Default;
}
observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType;
return observationProto;
}

26
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sensor = sensors[sensorIndex];
if (sensor.GetObservationShape().Length == 3)
if (sensor.GetObservationSpec().Shape.Length == 3)
{
if (!tensorsNames.Contains(
TensorNames.GetVisualObservationName(visObsIndex)))

}
visObsIndex++;
}
if (sensor.GetObservationShape().Length == 2)
if (sensor.GetObservationSpec().Shape.Length == 2)
{
if (!tensorsNames.Contains(
TensorNames.GetObservationName(sensorIndex)))

static FailedCheck CheckVisualObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var heightBp = shape[0];
var widthBp = shape[1];
var pixelBp = shape[2];

static FailedCheck CheckRankTwoObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var dim1Bp = shape[0];
var dim2Bp = shape[1];
var dim1T = tensorProxy.Channels;

for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationShape().Length == 3)
if (sens.GetObservationSpec().Shape.Length == 3)
{
tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] =

if (sens.GetObservationShape().Length == 2)
if (sens.GetObservationSpec().Shape.Length == 2)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);

var totalVectorSensorSize = 0;
foreach (var sens in sensors)
{
if ((sens.GetObservationShape().Length == 1))
if ((sens.GetObservationSpec().Shape.Length == 1))
totalVectorSensorSize += sens.GetObservationShape()[0];
totalVectorSensorSize += sens.GetObservationSpec().Shape[0];
}
}

foreach (var sensorComp in sensors)
{
if (sensorComp.GetObservationShape().Length == 1)
if (sensorComp.GetObservationSpec().Shape.Length == 1)
var vecSize = sensorComp.GetObservationShape()[0];
var vecSize = sensorComp.GetObservationSpec().Shape[0];
if (sensorSizes.Length == 0)
{
sensorSizes = $"[{vecSize}";

for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationShape().Length == 3)
if (sens.GetObservationSpec().Rank == 3)
if (sens.GetObservationShape().Length == 2)
if (sens.GetObservationSpec().Rank == 2)
if (sens.GetObservationShape().Length == 1)
if (sens.GetObservationSpec().Rank == 1)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankOneObsShape(tensor, sens);

9
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var sensor = sensors[sensorIndex];
var shape = sensor.GetObservationShape();
var rank = shape.Length;
var rank = sensor.GetObservationSpec().Rank;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)

}
if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
{
var sensor = sensors[sensorIndex];
var shape = sensor.GetObservationSpec().Shape;
var rank = shape.Length;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var obsGen = new ObservationGenerator(allocator);

2
com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs


{
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationShape(), 0);
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationSpec(), 0);
sensor.Write(m_ObservationWriter);
}
else

4
com.unity.ml-agents/Runtime/SensorHelper.cs


}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

24
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs


/// <summary>
/// A Sensor that allows to observe a variable number of entities.
/// </summary>
public class BufferSensor : ISensor, IDimensionPropertiesSensor, IBuiltInSensor
public class BufferSensor : ISensor, IBuiltInSensor
{
private string m_Name;
private int m_MaxNumObs;

static DimensionProperty[] s_DimensionProperties = new DimensionProperty[]{
DimensionProperty.VariableSize,
DimensionProperty.None
};
ObservationSpec m_ObservationSpec;
/// <summary>
/// Creates the BufferSensor.
/// </summary>
/// <param name="maxNumberObs">The maximum number of observations to be appended to this BufferSensor.</param>
/// <param name="obsSize">The size of each observation appended to the BufferSensor.</param>
/// <param name="name">The name of the sensor.</param>
public BufferSensor(int maxNumberObs, int obsSize, string name)
{
m_Name = name;

m_CurrentNumObservables = 0;
}
/// <inheritdoc/>
public int[] GetObservationShape()
{
return new int[] { m_MaxNumObs, m_ObsSize };
m_ObservationSpec = ObservationSpec.VariableLength(m_MaxNumObs, m_ObsSize);
public DimensionProperty[] GetDimensionProperties()
public ObservationSpec GetObservationSpec()
return s_DimensionProperties;
return m_ObservationSpec;
}
/// <summary>

32
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


/// <summary>
/// A sensor that wraps a Camera object to generate visual observations for an agent.
/// </summary>
public class CameraSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor
public class CameraSensor : ISensor, IBuiltInSensor
{
Camera m_Camera;
int m_Width;

int[] m_Shape;
private ObservationSpec m_ObservationSpec;
static DimensionProperty[] s_DimensionProperties = new DimensionProperty[] {
DimensionProperty.TranslationalEquivariance,
DimensionProperty.TranslationalEquivariance,
DimensionProperty.None };
/// <summary>
/// The Camera used for rendering the sensor observations.

m_Height = height;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = GenerateShape(width, height, grayscale);
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
m_CompressionType = compression;
}

return m_Name;
}
/// <summary>
/// Accessor for the size of the sensor data. Will be h x w x 1 for grayscale and
/// h x w x 3 for color.
/// </summary>
/// <returns>Size of each of the three dimensions.</returns>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <summary>
/// Accessor for the dimension properties of a camera sensor. A camera sensor
/// Has translational equivariance along width and hight and no property along
/// the channels dimension.
/// </summary>
/// <returns></returns>
public DimensionProperty[] GetDimensionProperties()
/// <inheritdoc/>
public ObservationSpec GetObservationSpec()
return s_DimensionProperties;
return m_ObservationSpec;
}
/// <summary>

70
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


}
/// <summary>
/// The Dimension property flags of the observations
/// </summary>
[System.Flags]
public enum DimensionProperty
{
/// <summary>
/// No properties specified.
/// </summary>
Unspecified = 0,
/// <summary>
/// No Property of the observation in that dimension. Observation can be processed with
/// fully connected networks.
/// </summary>
None = 1,
/// <summary>
/// Means it is suitable to do a convolution in this dimension.
/// </summary>
TranslationalEquivariance = 2,
/// <summary>
/// Means that there can be a variable number of observations in this dimension.
/// The observations are unordered.
/// </summary>
VariableSize = 4,
}
/// <summary>
/// The ObservationType enum of the Sensor.
/// </summary>
public enum ObservationType
{
/// <summary>
/// Collected observations are generic.
/// </summary>
Default = 0,
/// <summary>
/// Collected observations contain goal information.
/// </summary>
Goal = 1,
/// <summary>
/// Collected observations contain reward information.
/// </summary>
Reward = 2,
/// <summary>
/// Collected observations are messages from other agents.
/// </summary>
Message = 3,
}
/// <summary>
/// Returns the size of the observations that will be generated.
/// For example, a sensor that observes the velocity of a rigid body (in 3D) would return
/// new {3}. A sensor that returns an RGB image would return new [] {Height, Width, 3}
/// Returns a description of the observations that will be generated by the sensor.
/// See <see cref="ObservationSpec"/> for more details, and helper methods to create one.
/// <returns>Size of the observations that will be generated.</returns>
int[] GetObservationShape();
/// <returns></returns>
ObservationSpec GetObservationSpec();
/// <summary>
/// Write the observation data directly to the <see cref="ObservationWriter"/>.

/// <returns></returns>
public static int ObservationSize(this ISensor sensor)
{
var shape = sensor.GetObservationShape();
var obsSpec = sensor.GetObservationSpec();
foreach (var dim in shape)
for (var i = 0; i < obsSpec.Rank; i++)
count *= dim;
count *= obsSpec.Shape[i];
}
return count;

13
com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs


/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
/// <param name="observationSpec">ObservationSpec of the observation to be written</param>
/// <param name="offset">Offset from the start of the float data to write to.</param>
internal void SetTarget(IList<float> data, ObservationSpec observationSpec, int offset)
{
SetTarget(data, observationSpec.Shape, offset);
}
/// <summary>
/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
internal void SetTarget(IList<float> data, int[] shape, int offset)
internal void SetTarget(IList<float> data, InplaceArray<int> shape, int offset)
{
m_Data = data;
m_Offset = offset;

8
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


public class RayPerceptionSensor : ISensor, IBuiltInSensor
{
float[] m_Observations;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_Name;
RayPerceptionInput m_RayPerceptionInput;

void SetNumObservations(int numObservations)
{
m_Shape = new[] { numObservations };
m_ObservationSpec = ObservationSpec.Vector(numObservations);
m_Observations = new float[numObservations];
}

public void Reset() { }
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


// Cached sensor names and shapes.
string m_SensorName;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
int m_NumFloats;
public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size)
{

m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute;
m_SensorName = reflectionSensorInfo.SensorName;
m_Shape = new[] { size };
m_ObservationSpec = ObservationSpec.Vector(size);
m_NumFloats = size;
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

return m_Shape[0];
return m_NumFloats;
}
internal abstract void WriteReflectedField(ObservationWriter writer);

8
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs


RenderTexture m_RenderTexture;
bool m_Grayscale;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
/// <summary>

var height = renderTexture != null ? renderTexture.height : 0;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = new[] { height, width, grayscale ? 1 : 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, grayscale ? 1 : 3);
m_CompressionType = compressionType;
}

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

21
com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs


{
internal class SensorShapeValidator
{
List<int[]> m_SensorShapes;
List<ObservationSpec> m_SensorShapes;
/// <summary>
/// Check that the List Sensors are the same shape as the previous ones.

{
if (m_SensorShapes == null)
{
m_SensorShapes = new List<int[]>(sensors.Count);
m_SensorShapes = new List<ObservationSpec>(sensors.Count);
m_SensorShapes.Add(sensor.GetObservationShape());
m_SensorShapes.Add(sensor.GetObservationSpec());
}
}
else

);
for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++)
{
var cachedShape = m_SensorShapes[i];
var sensorShape = sensors[i].GetObservationShape();
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
for (var j = 0; j < Mathf.Min(cachedShape.Length, sensorShape.Length); j++)
{
Debug.Assert(cachedShape[j] == sensorShape[j], "Sensor sizes must match.");
}
var cachedSpec = m_SensorShapes[i];
var sensorSpec = sensors[i].GetObservationSpec();
Debug.AssertFormat(
cachedSpec.Shape == sensorSpec.Shape,
"Sensor shapes must match. {0} != {1}",
cachedSpec.Shape,
sensorSpec.Shape
);
}
}
}

48
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


int m_UnstackedObservationSize;
string m_Name;
int[] m_Shape;
int[] m_WrappedShape;
private ObservationSpec m_ObservationSpec;
private ObservationSpec m_WrappedSpec;
/// <summary>
/// Buffer of previous observations

m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";
m_WrappedShape = wrapped.GetObservationShape();
m_Shape = new int[m_WrappedShape.Length];
m_WrappedSpec = wrapped.GetObservationSpec();
for (int d = 0; d < m_WrappedShape.Length; d++)
{
m_Shape[d] = m_WrappedShape[d];
}
// Set up the cached observation spec for the StackingSensor
var newShape = m_WrappedSpec.Shape;
m_Shape[m_Shape.Length - 1] *= numStackedObservations;
newShape[newShape.Length - 1] *= numStackedObservations;
m_ObservationSpec = new ObservationSpec(
newShape, m_WrappedSpec.DimensionProperties, m_WrappedSpec.ObservationType
);
// Initialize uncompressed buffer anyway in case python trainer does not
// support the compression mapping and has to fall back to uncompressed obs.

m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
}
if (m_Shape.Length != 1)
if (m_WrappedSpec.Rank != 1)
m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]);
var wrappedShape = m_WrappedSpec.Shape;
m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]);
}
}

// First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one.
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedShape, 0);
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec, 0);
if (m_WrappedShape.Length == 1)
if (m_WrappedSpec.Rank == 1)
{
for (var i = 0; i < m_NumStackedObservations; i++)
{

for (var i = 0; i < m_NumStackedObservations; i++)
{
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
for (var h = 0; h < m_WrappedShape[0]; h++)
for (var h = 0; h < m_WrappedSpec.Shape[0]; h++)
for (var w = 0; w < m_WrappedShape[1]; w++)
for (var w = 0; w < m_WrappedSpec.Shape[1]; w++)
for (var c = 0; c < m_WrappedShape[2]; c++)
for (var c = 0; c < m_WrappedSpec.Shape[2]; c++)
writer[h, w, i * m_WrappedShape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
writer[h, w, i * m_WrappedSpec.Shape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
numWritten = m_WrappedShape[0] * m_WrappedShape[1] * m_WrappedShape[2] * m_NumStackedObservations;
numWritten = m_WrappedSpec.Shape[0] * m_WrappedSpec.Shape[1] * m_WrappedSpec.Shape[2] * m_NumStackedObservations;
}
return numWritten;

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

/// </summary>
internal byte[] CreateEmptyPNG()
{
int height = m_WrappedSensor.GetObservationShape()[0];
int width = m_WrappedSensor.GetObservationShape()[1];
var shape = m_WrappedSpec.Shape;
int height = shape[0];
int width = shape[1];
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
Color32[] resetColorArray = texture2D.GetPixels32();
Color32 black = new Color32(0, 0, 0, 0);

// wrapped sensor doesn't have one, use default mapping.
// Default mapping: {0, 0, 0} for grayscale, identity mapping {1, 2, ..., n} otherwise.
int[] wrappedMapping = null;
int wrappedNumChannel = wrappedSenesor.GetObservationShape()[2];
int wrappedNumChannel = m_WrappedSpec.Shape[2];
var sparseChannelSensor = m_WrappedSensor as ISparseChannelSensor;
if (sparseChannelSensor != null)
{

10
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


// TODO use float[] instead
// TODO allow setting float[]
List<float> m_Observations;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_Name;
/// <summary>

m_Observations = new List<float>(observationSize);
m_Name = name;
m_Shape = new[] { observationSize };
m_ObservationSpec = ObservationSpec.Vector(observationSize);
var expectedObservations = m_Shape[0];
var expectedObservations = m_ObservationSpec.Shape[0];
if (m_Observations.Count > expectedObservations)
{
// Too many observations, truncate

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

23
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


using System;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Demonstrations;

class DummySensor : ISensor
{
public int[] Shape;
public ObservationSpec ObservationSpec;
public SensorCompressionType CompressionType;
internal DummySensor()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return Shape;
return ObservationSpec;
}
public int Write(ObservationWriter writer)

foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants)
{
var inplaceShape = InplaceArray<int>.FromList(shape);
dummySensor.Shape = shape;
if (shape.Length == 1)
{
dummySensor.ObservationSpec = ObservationSpec.Vector(shape[0]);
}
else if (shape.Length == 3)
{
dummySensor.ObservationSpec = ObservationSpec.Visual(shape[0], shape[1], shape[2]);
}
else
{
throw new ArgumentOutOfRangeException();
}
obsWriter.SetTarget(new float[128], shape, 0);
obsWriter.SetTarget(new float[128], inplaceShape, 0);
var caps = new UnityRLCapabilities
{

19
com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs


public override int[] GetObservationShape()
{
return Sensor.GetObservationShape();
var shape = Sensor.GetObservationSpec().Shape;
return new int[] { shape[0], shape[1], shape[2] };
public class Test3DSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor
public class Test3DSensor : ISensor, IBuiltInSensor
{
int m_Width;
int m_Height;

m_Name = name;
}
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new[] { m_Height, m_Width, m_Channels };
return ObservationSpec.Visual(m_Height, m_Width, m_Channels);
}
public int Write(ObservationWriter writer)

public BuiltInSensorType GetBuiltInSensorType()
{
return (BuiltInSensorType)k_BuiltInSensorType;
}
public DimensionProperty[] GetDimensionProperties()
{
return new[]
{
DimensionProperty.TranslationalEquivariance,
DimensionProperty.TranslationalEquivariance,
DimensionProperty.None
};
}
}

4
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


sensorName = n;
}
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new[] { 0 };
return ObservationSpec.Vector(0);
}
public int Write(ObservationWriter writer)

6
com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs


{
var bufferSensor = new BufferSensor(20, 4, "testName");
var shape = bufferSensor.GetObservationShape();
var dimProp = bufferSensor.GetDimensionProperties();
var shape = bufferSensor.GetObservationSpec().Shape;
var dimProp = bufferSensor.GetObservationSpec().DimensionProperties;
Assert.AreEqual(shape[0], 20);
Assert.AreEqual(shape[1], 4);
Assert.AreEqual(shape.Length, 2);

var obsWriter = new ObservationWriter();
var obs = bufferSensor.GetObservationProto(obsWriter);
Assert.AreEqual(shape, obs.Shape);
Assert.AreEqual(shape, InplaceArray<int>.FromList(obs.Shape));
Assert.AreEqual(obs.DimensionProperties.Count, 2);
Assert.AreEqual((int)dimProp[0], obs.DimensionProperties[0]);
Assert.AreEqual((int)dimProp[1], obs.DimensionProperties[1]);

3
com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs


Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape());