浏览代码

[MLA-1634] Add ObservationSpec and update ISensor interfaces (#5127)

/v2-staging-rebase
GitHub 4 年前
当前提交
7c7d02ce
共有 56 个文件被更改,包括 1025 次插入368 次删除
  1. 2
      DevProject/Packages/manifest.json
  2. 2
      DevProject/Packages/packages-lock.json
  3. 18
      DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json
  4. 4
      DevProject/ProjectSettings/ProjectVersion.txt
  5. 4
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
  6. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
  7. 5
      Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta
  8. 8
      Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
  9. 12
      com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
  10. 30
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  11. 10
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  12. 12
      com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
  13. 8
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs
  14. 6
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs
  15. 9
      com.unity.ml-agents/CHANGELOG.md
  16. 1
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
  17. 7
      com.unity.ml-agents/Runtime/Analytics/Events.cs
  18. 55
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  19. 28
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  20. 3
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  21. 2
      com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
  22. 4
      com.unity.ml-agents/Runtime/SensorHelper.cs
  23. 19
      com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
  24. 32
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  25. 70
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  26. 13
      com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
  27. 8
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
  28. 12
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  29. 8
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
  30. 21
      com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
  31. 48
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  32. 10
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  33. 23
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  34. 19
      com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
  35. 4
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  36. 6
      com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs
  37. 3
      com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs
  38. 13
      com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
  39. 2
      com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs
  40. 20
      com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs
  41. 2
      com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs
  42. 27
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  43. 30
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
  44. 20
      docs/Migrating.md
  45. 2
      com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
  46. 241
      com.unity.ml-agents/Runtime/InplaceArray.cs
  47. 3
      com.unity.ml-agents/Runtime/InplaceArray.cs.meta
  48. 140
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
  49. 3
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta
  50. 192
      com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs
  51. 78
      com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs
  52. 3
      com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta
  53. 31
      com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs
  54. 11
      com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta
  55. 47
      com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs
  56. 0
      /com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta

2
DevProject/Packages/manifest.json


"com.unity.package-manager-doctools": "1.7.0-preview",
"com.unity.package-validation-suite": "0.19.0-preview",
"com.unity.purchasing": "2.2.1",
"com.unity.test-framework": "1.1.20",
"com.unity.test-framework": "1.1.22",
"com.unity.test-framework.performance": "2.2.0-preview",
"com.unity.testtools.codecoverage": "1.0.0-pre.3",
"com.unity.textmeshpro": "2.0.1",

2
DevProject/Packages/packages-lock.json


"url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates"
},
"com.unity.test-framework": {
"version": "1.1.20",
"version": "1.1.22",
"depth": 0,
"source": "registry",
"dependencies": {

18
DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json


"m_Name": "Settings",
"m_Path": "ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json",
"m_Dictionary": {
"m_DictionaryValues": []
"m_DictionaryValues": [
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "Path",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "HistoryPath",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "IncludeAssemblies",
"value": "{\"m_Value\":\"Assembly-CSharp,Runtime,Unity.ML-Agents,Unity.ML-Agents.Extensions\"}"
}
]
}
}

4
DevProject/ProjectSettings/ProjectVersion.txt


m_EditorVersion: 2019.4.19f1
m_EditorVersionWithRevision: 2019.4.19f1 (ca5b14067cec)
m_EditorVersion: 2019.4.20f1
m_EditorVersionWithRevision: 2019.4.20f1 (6dd1c08eedfa)

4
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs


}
/// <inheritdoc/>
public override int[] GetObservationShape()
public override ObservationSpec GetObservationSpec()
return new[] { BasicController.k_Extents };
return ObservationSpec.Vector(BasicController.k_Extents);
}
/// <inheritdoc/>

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs


public abstract void WriteObservation(float[] output);
/// <inheritdoc/>
public abstract int[] GetObservationShape();
public abstract ObservationSpec GetObservationSpec();
/// <inheritdoc/>
public abstract string GetName();

5
Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta


fileFormatVersion: 2
guid: 8cd4584c2f2cb4c5fb51675d364e10ec
ScriptedImporter:
internalIDToNameTable: []
fileIDToRecycleName:
11400000: main obj
11400002: model data
serializedVersion: 2
userData:
assetBundleName:
assetBundleVariant:

8
Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs


{
Texture2D m_Texture;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
/// <summary>

var width = texture.width;
var height = texture.height;
m_Name = name;
m_Shape = new[] { height, width, 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, 3);
m_CompressionType = compressionType;
}

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs


{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
private int[] m_SparseChannelMapping;
private string m_Name;

m_NumSpecialTypes = board.NumSpecialTypes;
m_ObservationType = obsType;
m_Shape = obsType == Match3ObservationType.Vector ?
new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
m_ObservationSpec = obsType == Match3ObservationType.Vector
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

30
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


protected bool Initialized = false;
/// <summary>
/// Array holding the dimensions of the resulting tensor
/// Cached ObservationSpec
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
//
// Debug Parameters

// Default root reference to current game object
if (rootReference == null)
rootReference = gameObject;
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
compressedImgs = new List<byte[]>();
byteSizesBytesList = new List<byte[]>();

CellActivity[i] = DebugDefaultColor;
}
}
}
/// <summary>Gets the shape of the grid observation</summary>
/// <returns>integer array shape of the grid observation</returns>
public int[] GetFloatObservationShape()
{
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
}
/// <inheritdoc/>

/// <summary>Gets the observation shape</summary>
/// <returns>int[] of the observation shape</returns>
public ObservationSpec GetObservationSpec()
{
// Lazy update
var shape = m_ObservationSpec.Shape;
if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell)
{
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
}
return m_ObservationSpec;
}
/// <inheritdoc/>
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
var shape = m_ObservationSpec.Shape;
return new int[] { shape[0], shape[1], shape[2] };
}
/// <inheritdoc/>

10
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


/// </summary>
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_SensorName;
PoseExtractor m_PoseExtractor;

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}
#if UNITY_2020_1_OR_NEWER

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs


var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

8
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

6
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 7 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
}
}

9
com.unity.ml-agents/CHANGELOG.md


## [Unreleased]
### Major Changes
#### com.unity.ml-agents (C#)
- Several breaking interface changes were made. See the
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details.
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. See the
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. (#5060)
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details. (#5060)
- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. (#5127)
#### ml-agents / ml-agents-envs / gym-unity (Python)
### Minor Changes

1
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs


/// <summary>
/// Set whether or not the action index for the given branch is allowed.
/// </summary>
/// <remarks>
/// By default, all discrete actions are allowed.
/// If isEnabled is false, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndex correspond

7
com.unity.ml-agents/Runtime/Analytics/Events.cs


public static EventObservationSpec FromSensor(ISensor sensor)
{
var shape = sensor.GetObservationShape();
var dimProps = (sensor as IDimensionPropertiesSensor)?.GetDimensionProperties();
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
var dimProps = obsSpec.DimensionProperties;
dimInfos[i].Flags = dimProps != null ? (int)dimProps[i] : 0;
dimInfos[i].Flags = (int)dimProps[i];
}
var builtInSensorType =

55
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


/// <returns></returns>
public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
{
var shape = sensor.GetObservationShape();
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
ObservationProto observationProto = null;
var compressionType = sensor.GetCompressionType();
// Check capabilities if we need to concatenate PNGs

floatDataProto.Data.Add(0.0f);
}
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0);
sensor.Write(observationWriter);
observationProto = new ObservationProto

observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
}
}
// Add the dimension properties if any to the observationProto
var dimensionPropertySensor = sensor as IDimensionPropertiesSensor;
if (dimensionPropertySensor != null)
// Add the dimension properties to the observationProto
var dimensionProperties = obsSpec.DimensionProperties;
for (int i = 0; i < dimensionProperties.Length; i++)
{
observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
}
// Checking trainer compatibility with variable length observations
if (dimensionProperties == new InplaceArray<DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None))
var dimensionProperties = dimensionPropertySensor.GetDimensionProperties();
for (int i = 0; i < dimensionProperties.Length; i++)
var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
if (!trainerCanHandleVarLenObs)
observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
// Checking trainer compatibility with variable length observations
if (dimensionProperties.Length == 2)
{
if (dimensionProperties[0] == DimensionProperty.VariableSize &&
dimensionProperties[1] == DimensionProperty.None)
{
var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
if (!trainerCanHandleVarLenObs)
{
throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
}
}
}
}
for (var i = 0; i < shape.Length; i++)
{
observationProto.Shape.Add(shape[i]);
observationProto.Shape.AddRange(shape);
var sensorName = sensor.GetName();
if (!string.IsNullOrEmpty(sensorName))
{

// Add the observation type, if any, to the observationProto
var typeSensor = sensor as ITypedSensor;
if (typeSensor != null)
{
observationProto.ObservationType = (ObservationTypeProto)typeSensor.GetObservationType();
}
else
{
observationProto.ObservationType = ObservationTypeProto.Default;
}
observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType;
return observationProto;
}

28
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sensor = sensors[sensorIndex];
if (sensor.GetObservationShape().Length == 3)
if (sensor.GetObservationSpec().Shape.Length == 3)
{
if (!tensorsNames.Contains(
TensorNames.GetVisualObservationName(visObsIndex)))

}
visObsIndex++;
}
if (sensor.GetObservationShape().Length == 2)
if (sensor.GetObservationSpec().Shape.Length == 2)
{
if (!tensorsNames.Contains(
TensorNames.GetObservationName(sensorIndex)))

static FailedCheck CheckVisualObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var heightBp = shape[0];
var widthBp = shape[1];
var pixelBp = shape[2];

static FailedCheck CheckRankTwoObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var dim1Bp = shape[0];
var dim2Bp = shape[1];
var dim1T = tensorProxy.Channels;

static FailedCheck CheckRankOneObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var dim1Bp = shape[0];
var dim1T = tensorProxy.Channels;
var dim2T = tensorProxy.Width;

for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationShape().Length == 3)
if (sens.GetObservationSpec().Shape.Length == 3)
{
tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] =

if (sens.GetObservationShape().Length == 2)
if (sens.GetObservationSpec().Shape.Length == 2)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);

var totalVectorSensorSize = 0;
foreach (var sens in sensors)
{
if ((sens.GetObservationShape().Length == 1))
if ((sens.GetObservationSpec().Shape.Length == 1))
totalVectorSensorSize += sens.GetObservationShape()[0];
totalVectorSensorSize += sens.GetObservationSpec().Shape[0];
}
}

foreach (var sensorComp in sensors)
{
if (sensorComp.GetObservationShape().Length == 1)
if (sensorComp.GetObservationSpec().Shape.Length == 1)
var vecSize = sensorComp.GetObservationShape()[0];
var vecSize = sensorComp.GetObservationSpec().Shape[0];
if (sensorSizes.Length == 0)
{
sensorSizes = $"[{vecSize}";

for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationShape().Length == 3)
if (sens.GetObservationSpec().Rank == 3)
if (sens.GetObservationShape().Length == 2)
if (sens.GetObservationSpec().Rank == 2)
if (sens.GetObservationShape().Length == 1)
if (sens.GetObservationSpec().Rank == 1)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankOneObsShape(tensor, sens);

3
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var sensor = sensors[sensorIndex];
var shape = sensor.GetObservationShape();
var rank = shape.Length;
var rank = sensor.GetObservationSpec().Rank;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)

2
com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs


{
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationShape(), 0);
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationSpec(), 0);
sensor.Write(m_ObservationWriter);
}
else

4
com.unity.ml-agents/Runtime/SensorHelper.cs


}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

19
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs


/// <summary>
/// A Sensor that allows to observe a variable number of entities.
/// </summary>
public class BufferSensor : ISensor, IDimensionPropertiesSensor, IBuiltInSensor
public class BufferSensor : ISensor, IBuiltInSensor
{
private string m_Name;
private int m_MaxNumObs;

static DimensionProperty[] s_DimensionProperties = new DimensionProperty[]{
DimensionProperty.VariableSize,
DimensionProperty.None
};
ObservationSpec m_ObservationSpec;
public BufferSensor(int maxNumberObs, int obsSize, string name)
{
m_Name = name;

m_CurrentNumObservables = 0;
m_ObservationSpec = ObservationSpec.VariableLength(m_MaxNumObs, m_ObsSize);
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new int[] { m_MaxNumObs, m_ObsSize };
}
/// <inheritdoc/>
public DimensionProperty[] GetDimensionProperties()
{
return s_DimensionProperties;
return m_ObservationSpec;
}
/// <summary>

32
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


/// <summary>
/// A sensor that wraps a Camera object to generate visual observations for an agent.
/// </summary>
public class CameraSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor
public class CameraSensor : ISensor, IBuiltInSensor
{
Camera m_Camera;
int m_Width;

int[] m_Shape;
private ObservationSpec m_ObservationSpec;
static DimensionProperty[] s_DimensionProperties = new DimensionProperty[] {
DimensionProperty.TranslationalEquivariance,
DimensionProperty.TranslationalEquivariance,
DimensionProperty.None };
/// <summary>
/// The Camera used for rendering the sensor observations.

m_Height = height;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = GenerateShape(width, height, grayscale);
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
m_CompressionType = compression;
}

return m_Name;
}
/// <summary>
/// Accessor for the size of the sensor data. Will be h x w x 1 for grayscale and
/// h x w x 3 for color.
/// </summary>
/// <returns>Size of each of the three dimensions.</returns>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <summary>
/// Accessor for the dimension properties of a camera sensor. A camera sensor
/// Has translational equivariance along width and hight and no property along
/// the channels dimension.
/// </summary>
/// <returns></returns>
public DimensionProperty[] GetDimensionProperties()
/// <inheritdoc/>
public ObservationSpec GetObservationSpec()
return s_DimensionProperties;
return m_ObservationSpec;
}
/// <summary>

70
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


}
/// <summary>
/// The Dimension property flags of the observations
/// </summary>
[System.Flags]
public enum DimensionProperty
{
/// <summary>
/// No properties specified.
/// </summary>
Unspecified = 0,
/// <summary>
/// No Property of the observation in that dimension. Observation can be processed with
/// fully connected networks.
/// </summary>
None = 1,
/// <summary>
/// Means it is suitable to do a convolution in this dimension.
/// </summary>
TranslationalEquivariance = 2,
/// <summary>
/// Means that there can be a variable number of observations in this dimension.
/// The observations are unordered.
/// </summary>
VariableSize = 4,
}
/// <summary>
/// The ObservationType enum of the Sensor.
/// </summary>
public enum ObservationType
{
/// <summary>
/// Collected observations are generic.
/// </summary>
Default = 0,
/// <summary>
/// Collected observations contain goal information.
/// </summary>
Goal = 1,
/// <summary>
/// Collected observations contain reward information.
/// </summary>
Reward = 2,
/// <summary>
/// Collected observations are messages from other agents.
/// </summary>
Message = 3,
}
/// <summary>
/// Returns the size of the observations that will be generated.
/// For example, a sensor that observes the velocity of a rigid body (in 3D) would return
/// new {3}. A sensor that returns an RGB image would return new [] {Height, Width, 3}
/// Returns a description of the observations that will be generated by the sensor.
/// See <see cref="ObservationSpec"/> for more details, and helper methods to create one.
/// <returns>Size of the observations that will be generated.</returns>
int[] GetObservationShape();
/// <returns></returns>
ObservationSpec GetObservationSpec();
/// <summary>
/// Write the observation data directly to the <see cref="ObservationWriter"/>.

/// <returns></returns>
public static int ObservationSize(this ISensor sensor)
{
var shape = sensor.GetObservationShape();
var obsSpec = sensor.GetObservationSpec();
foreach (var dim in shape)
for (var i = 0; i < obsSpec.Rank; i++)
count *= dim;
count *= obsSpec.Shape[i];
}
return count;

13
com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs


/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
/// <param name="observationSpec">ObservationSpec of the observation to be written</param>
/// <param name="offset">Offset from the start of the float data to write to.</param>
internal void SetTarget(IList<float> data, ObservationSpec observationSpec, int offset)
{
SetTarget(data, observationSpec.Shape, offset);
}
/// <summary>
/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
internal void SetTarget(IList<float> data, int[] shape, int offset)
internal void SetTarget(IList<float> data, InplaceArray<int> shape, int offset)
{
m_Data = data;
m_Offset = offset;

8
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


public class RayPerceptionSensor : ISensor, IBuiltInSensor
{
float[] m_Observations;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_Name;
RayPerceptionInput m_RayPerceptionInput;

void SetNumObservations(int numObservations)
{
m_Shape = new[] { numObservations };
m_ObservationSpec = ObservationSpec.Vector(numObservations);
m_Observations = new float[numObservations];
}

public void Reset() { }
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


// Cached sensor names and shapes.
string m_SensorName;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
int m_NumFloats;
public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size)
{

m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute;
m_SensorName = reflectionSensorInfo.SensorName;
m_Shape = new[] { size };
m_ObservationSpec = ObservationSpec.Vector(size);
m_NumFloats = size;
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

return m_Shape[0];
return m_NumFloats;
}
internal abstract void WriteReflectedField(ObservationWriter writer);

8
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs


RenderTexture m_RenderTexture;
bool m_Grayscale;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
/// <summary>

var height = renderTexture != null ? renderTexture.height : 0;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = new[] { height, width, grayscale ? 1 : 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, grayscale ? 1 : 3);
m_CompressionType = compressionType;
}

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

21
com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs


{
internal class SensorShapeValidator
{
List<int[]> m_SensorShapes;
List<ObservationSpec> m_SensorShapes;
/// <summary>
/// Check that the List Sensors are the same shape as the previous ones.

{
if (m_SensorShapes == null)
{
m_SensorShapes = new List<int[]>(sensors.Count);
m_SensorShapes = new List<ObservationSpec>(sensors.Count);
m_SensorShapes.Add(sensor.GetObservationShape());
m_SensorShapes.Add(sensor.GetObservationSpec());
}
}
else

);
for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++)
{
var cachedShape = m_SensorShapes[i];
var sensorShape = sensors[i].GetObservationShape();
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
for (var j = 0; j < Mathf.Min(cachedShape.Length, sensorShape.Length); j++)
{
Debug.Assert(cachedShape[j] == sensorShape[j], "Sensor sizes must match.");
}
var cachedSpec = m_SensorShapes[i];
var sensorSpec = sensors[i].GetObservationSpec();
Debug.AssertFormat(
cachedSpec.Shape == sensorSpec.Shape,
"Sensor shapes must match. {0} != {1}",
cachedSpec.Shape,
sensorSpec.Shape
);
}
}
}

48
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


int m_UnstackedObservationSize;
string m_Name;
int[] m_Shape;
int[] m_WrappedShape;
private ObservationSpec m_ObservationSpec;
private ObservationSpec m_WrappedSpec;
/// <summary>
/// Buffer of previous observations

m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";
m_WrappedShape = wrapped.GetObservationShape();
m_Shape = new int[m_WrappedShape.Length];
m_WrappedSpec = wrapped.GetObservationSpec();
for (int d = 0; d < m_WrappedShape.Length; d++)
{
m_Shape[d] = m_WrappedShape[d];
}
// Set up the cached observation spec for the StackingSensor
var newShape = m_WrappedSpec.Shape;
m_Shape[m_Shape.Length - 1] *= numStackedObservations;
newShape[newShape.Length - 1] *= numStackedObservations;
m_ObservationSpec = new ObservationSpec(
newShape, m_WrappedSpec.DimensionProperties, m_WrappedSpec.ObservationType
);
// Initialize uncompressed buffer anyway in case python trainer does not
// support the compression mapping and has to fall back to uncompressed obs.

m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
}
if (m_Shape.Length != 1)
if (m_WrappedSpec.Rank != 1)
m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]);
var wrappedShape = m_WrappedSpec.Shape;
m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]);
}
}

// First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one.
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedShape, 0);
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec, 0);
if (m_WrappedShape.Length == 1)
if (m_WrappedSpec.Rank == 1)
{
for (var i = 0; i < m_NumStackedObservations; i++)
{

for (var i = 0; i < m_NumStackedObservations; i++)
{
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
for (var h = 0; h < m_WrappedShape[0]; h++)
for (var h = 0; h < m_WrappedSpec.Shape[0]; h++)
for (var w = 0; w < m_WrappedShape[1]; w++)
for (var w = 0; w < m_WrappedSpec.Shape[1]; w++)
for (var c = 0; c < m_WrappedShape[2]; c++)
for (var c = 0; c < m_WrappedSpec.Shape[2]; c++)
writer[h, w, i * m_WrappedShape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
writer[h, w, i * m_WrappedSpec.Shape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
numWritten = m_WrappedShape[0] * m_WrappedShape[1] * m_WrappedShape[2] * m_NumStackedObservations;
numWritten = m_WrappedSpec.Shape[0] * m_WrappedSpec.Shape[1] * m_WrappedSpec.Shape[2] * m_NumStackedObservations;
}
return numWritten;

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

/// </summary>
internal byte[] CreateEmptyPNG()
{
int height = m_WrappedSensor.GetObservationShape()[0];
int width = m_WrappedSensor.GetObservationShape()[1];
var shape = m_WrappedSpec.Shape;
int height = shape[0];
int width = shape[1];
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
Color32[] resetColorArray = texture2D.GetPixels32();
Color32 black = new Color32(0, 0, 0, 0);

// wrapped sensor doesn't have one, use default mapping.
// Default mapping: {0, 0, 0} for grayscale, identity mapping {1, 2, ..., n} otherwise.
int[] wrappedMapping = null;
int wrappedNumChannel = wrappedSenesor.GetObservationShape()[2];
int wrappedNumChannel = m_WrappedSpec.Shape[2];
var sparseChannelSensor = m_WrappedSensor as ISparseChannelSensor;
if (sparseChannelSensor != null)
{

10
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


// TODO use float[] instead
// TODO allow setting float[]
List<float> m_Observations;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_Name;
/// <summary>

m_Observations = new List<float>(observationSize);
m_Name = name;
m_Shape = new[] { observationSize };
m_ObservationSpec = ObservationSpec.Vector(observationSize);
var expectedObservations = m_Shape[0];
var expectedObservations = m_ObservationSpec.Shape[0];
if (m_Observations.Count > expectedObservations)
{
// Too many observations, truncate

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

23
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


using System;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Demonstrations;

class DummySensor : ISensor
{
public int[] Shape;
public ObservationSpec ObservationSpec;
public SensorCompressionType CompressionType;
internal DummySensor()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return Shape;
return ObservationSpec;
}
public int Write(ObservationWriter writer)

foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants)
{
var inplaceShape = InplaceArray<int>.FromList(shape);
dummySensor.Shape = shape;
if (shape.Length == 1)
{
dummySensor.ObservationSpec = ObservationSpec.Vector(shape[0]);
}
else if (shape.Length == 3)
{
dummySensor.ObservationSpec = ObservationSpec.Visual(shape[0], shape[1], shape[2]);
}
else
{
throw new ArgumentOutOfRangeException();
}
obsWriter.SetTarget(new float[128], shape, 0);
obsWriter.SetTarget(new float[128], inplaceShape, 0);
var caps = new UnityRLCapabilities
{

19
com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs


public override int[] GetObservationShape()
{
return Sensor.GetObservationShape();
var shape = Sensor.GetObservationSpec().Shape;
return new int[] { shape[0], shape[1], shape[2] };
public class Test3DSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor
public class Test3DSensor : ISensor, IBuiltInSensor
{
int m_Width;
int m_Height;

m_Name = name;
}
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new[] { m_Height, m_Width, m_Channels };
return ObservationSpec.Visual(m_Height, m_Width, m_Channels);
}
public int Write(ObservationWriter writer)

public BuiltInSensorType GetBuiltInSensorType()
{
return (BuiltInSensorType)k_BuiltInSensorType;
}
public DimensionProperty[] GetDimensionProperties()
{
return new[]
{
DimensionProperty.TranslationalEquivariance,
DimensionProperty.TranslationalEquivariance,
DimensionProperty.None
};
}
}

4
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


sensorName = n;
}
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new[] { 0 };
return ObservationSpec.Vector(0);
}
public int Write(ObservationWriter writer)

6
com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs


{
var bufferSensor = new BufferSensor(20, 4, "testName");
var shape = bufferSensor.GetObservationShape();
var dimProp = bufferSensor.GetDimensionProperties();
var shape = bufferSensor.GetObservationSpec().Shape;
var dimProp = bufferSensor.GetObservationSpec().DimensionProperties;
Assert.AreEqual(shape[0], 20);
Assert.AreEqual(shape[1], 4);
Assert.AreEqual(shape.Length, 2);

var obsWriter = new ObservationWriter();
var obs = bufferSensor.GetObservationProto(obsWriter);
Assert.AreEqual(shape, obs.Shape);
Assert.AreEqual(shape, InplaceArray<int>.FromList(obs.Shape));
Assert.AreEqual(obs.DimensionProperties.Count, 2);
Assert.AreEqual((int)dimProp[0], obs.DimensionProperties[0]);
Assert.AreEqual((int)dimProp[1], obs.DimensionProperties[1]);

3
com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs


Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape());
var sensor = cameraComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
var expectedShapeInplace = new InplaceArray<int>(height, width, grayscale ? 1 : 3);
Assert.AreEqual(expectedShapeInplace, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(CameraSensor), sensor.GetType());
}
}

13
com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs


public int Width { get; }
public int Height { get; }
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
public float[,] floatData;
public Float2DSensor(int width, int height, string name)

m_Name = name;
m_Shape = new[] { height, width, 1 };
m_ObservationSpec = ObservationSpec.Visual(height, width, 1);
floatData = new float[Height, Width];
}

Height = floatData.GetLength(0);
Width = floatData.GetLength(1);
m_Name = name;
m_Shape = new[] { Height, Width, 1 };
m_ObservationSpec = ObservationSpec.Visual(Height, Width, 1);
}
public string GetName()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
public byte[] GetCompressedObservation()

var output = new float[12];
var writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
sensor.Write(writer);
for (var i = 0; i < 9; i++)
{

2
com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs


{
ObservationWriter writer = new ObservationWriter();
var buffer = new[] { 0f, 0f, 0f };
var shape = new[] { 3 };
var shape = new InplaceArray<int>(3);
writer.SetTarget(buffer, shape, 0);
// Elementwise writes

20
com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs


var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

2
com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs


Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape());
var sensor = renderTexComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType());
}
}

27
com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs


using System.Collections.Generic;
using System.Text.RegularExpressions;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;

public class DummySensor : ISensor
{
string m_Name = "DummySensor";
int[] m_Shape;
ObservationSpec m_ObservationSpec;
m_Shape = new[] { dim1 };
m_ObservationSpec = ObservationSpec.Vector(dim1);
m_Shape = new[] { dim1, dim2, };
m_ObservationSpec = ObservationSpec.VariableLength(dim1, dim2);
m_Shape = new[] { dim1, dim2, dim3 };
m_ObservationSpec = ObservationSpec.Visual(dim1, dim2, dim3);
}
public string GetName()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
public byte[] GetCompressedObservation()

validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5) };
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList1);
}

validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 7) };
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList1);
}

var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(9) };
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList2);
// Add the sensors in the other order

LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList1);
}
}

30
com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs


ISensor wrapped = new VectorSensor(4);
ISensor sensor = new StackingSensor(wrapped, 4);
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
Assert.AreEqual(sensor.GetObservationShape(), new[] { 16 });
Assert.AreEqual(sensor.GetObservationSpec().Shape, new InplaceArray<int>(16));
}
[Test]

{
public SensorCompressionType CompressionType = SensorCompressionType.PNG;
public int[] Mapping;
public int[] Shape;
public ObservationSpec ObservationSpec;
public float[,,] CurrentObservation;
internal Dummy3DSensor()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return Shape;
return ObservationSpec;
for (var h = 0; h < Shape[0]; h++)
for (var h = 0; h < ObservationSpec.Shape[0]; h++)
for (var w = 0; w < Shape[1]; w++)
for (var w = 0; w < ObservationSpec.Shape[1]; w++)
for (var c = 0; c < Shape[2]; c++)
for (var c = 0; c < ObservationSpec.Shape[2]; c++)
return Shape[0] * Shape[1] * Shape[2];
return ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2];
var flattenedObservation = new float[Shape[0] * Shape[1] * Shape[2]];
writer.SetTarget(flattenedObservation, Shape, 0);
var flattenedObservation = new float[ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2]];
writer.SetTarget(flattenedObservation, ObservationSpec.Shape, 0);
Write(writer);
byte[] bytes = Array.ConvertAll(flattenedObservation, (z) => (byte)z);
return bytes;

// Test mapping with number of layers not being multiple of 3
var dummySensor = new Dummy3DSensor();
dummySensor.Shape = new[] { 2, 2, 4 };
dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });

paddedDummySensor.Shape = new[] { 2, 2, 4 };
paddedDummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 };
var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2);
Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });

public void Test3DStacking()
{
var wrapped = new Dummy3DSensor();
wrapped.Shape = new[] { 2, 1, 2 };
wrapped.ObservationSpec = ObservationSpec.Visual(2, 1, 2);
var sensor = new StackingSensor(wrapped, 2);
// Check the stacking is on the last dimension

public void TestStackedGetCompressedObservation()
{
var wrapped = new Dummy3DSensor();
wrapped.Shape = new[] { 1, 1, 3 };
wrapped.ObservationSpec = ObservationSpec.Visual(1, 1, 3);
var sensor = new StackingSensor(wrapped, 2);
wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } };

public void TestStackingSensorBuiltInSensorType()
{
var dummySensor = new Dummy3DSensor();
dummySensor.Shape = new[] { 2, 2, 4 };
dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetBuiltInSensorType(), BuiltInSensorType.Unknown);

20
docs/Migrating.md


- The `IActuator` interface now implements `IHeuristicProvider`. Please add the corresponding `Heuristic(in ActionBuffers)`
method to your custom Actuator classes.
- The `ISensor.GetObservationShape()` method was removed, and `GetObservationSpec()` was added. You can use
`ObservationSpec.Vector()` or `ObservationSpec.Visual()` to generate `ObservationSpec`s that are equivalent to
the previous shape. For example, if your old ISensor looked like:
```csharp
public override int[] GetObservationShape()
{
return new[] { m_Height, m_Width, m_NumChannels };
}
```
the equivalent code would now be
```csharp
public override ObservationSpec GetObservationSpec()
{
return ObservationSpec.Visual(m_Height, m_Width, m_NumChannels);
}
```
## Migrating to Release 13
### Implementing IHeuristic in your IActuator implementations
- If you have any custom actuators, you can now implement the `IHeuristicProvider` interface to have your actuator

2
com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta


fileFormatVersion: 2
guid: 297e9ec12d6de45adbcf6dea1a9de019
guid: 8e1cdc27e533749fabc04b3cdeb93501
MonoImporter:
externalObjects: {}
serializedVersion: 2

241
com.unity.ml-agents/Runtime/InplaceArray.cs


using System;
using System.Collections.Generic;
namespace Unity.MLAgents
{
/// <summary>
/// An array-like object that stores up to four elements.
/// This is a value type that does not allocate any additional memory.
/// </summary>
/// <remarks>
/// This does not implement any interfaces such as IList, in order to avoid any accidental boxing allocations.
/// </remarks>
/// <typeparam name="T"></typeparam>
public struct InplaceArray<T> : IEquatable<InplaceArray<T>> where T : struct
{
private const int k_MaxLength = 4;
private readonly int m_Length;
private T m_Elem0;
private T m_Elem1;
private T m_Elem2;
private T m_Elem3;
/// <summary>
/// Create a length-1 array.
/// </summary>
/// <param name="elem0"></param>
public InplaceArray(T elem0)
{
m_Length = 1;
m_Elem0 = elem0;
m_Elem1 = new T { };
m_Elem2 = new T { };
m_Elem3 = new T { };
}
/// <summary>
/// Create a length-2 array.
/// </summary>
/// <param name="elem0"></param>
/// <param name="elem1"></param>
public InplaceArray(T elem0, T elem1)
{
m_Length = 2;
m_Elem0 = elem0;
m_Elem1 = elem1;
m_Elem2 = new T { };
m_Elem3 = new T { };
}
/// <summary>
/// Create a length-3 array.
/// </summary>
/// <param name="elem0"></param>
/// <param name="elem1"></param>
/// <param name="elem2"></param>
public InplaceArray(T elem0, T elem1, T elem2)
{
m_Length = 3;
m_Elem0 = elem0;
m_Elem1 = elem1;
m_Elem2 = elem2;
m_Elem3 = new T { };
}
/// <summary>
/// Create a length-3 array.
/// </summary>
/// <param name="elem0"></param>
/// <param name="elem1"></param>
/// <param name="elem2"></param>
/// <param name="elem3"></param>
public InplaceArray(T elem0, T elem1, T elem2, T elem3)
{
m_Length = 4;
m_Elem0 = elem0;
m_Elem1 = elem1;
m_Elem2 = elem2;
m_Elem3 = elem3;
}
/// <summary>
/// Construct an InplaceArray from an IList (e.g. Array or List).
/// The source must be non-empty and have at most 4 elements.
/// </summary>
/// <param name="elems"></param>
/// <returns></returns>
/// <exception cref="ArgumentOutOfRangeException"></exception>
public static InplaceArray<T> FromList(IList<T> elems)
{
switch (elems.Count)
{
case 1:
return new InplaceArray<T>(elems[0]);
case 2:
return new InplaceArray<T>(elems[0], elems[1]);
case 3:
return new InplaceArray<T>(elems[0], elems[1], elems[2]);
case 4:
return new InplaceArray<T>(elems[0], elems[1], elems[2], elems[3]);
default:
throw new ArgumentOutOfRangeException();
}
}
/// <summary>
/// Per-element access.
/// </summary>
/// <param name="index"></param>
/// <exception cref="IndexOutOfRangeException"></exception>
public T this[int index]
{
get
{
if (index >= Length)
{
throw new IndexOutOfRangeException();
}
switch (index)
{
case 0:
return m_Elem0;
case 1:
return m_Elem1;
case 2:
return m_Elem2;
case 3:
return m_Elem3;
default:
throw new IndexOutOfRangeException();
}
}
set
{
if (index >= Length)
{
throw new IndexOutOfRangeException();
}
switch (index)
{
case 0:
m_Elem0 = value;
break;
case 1:
m_Elem1 = value;
break;
case 2:
m_Elem2 = value;
break;
case 3:
m_Elem3 = value;
break;
default:
throw new IndexOutOfRangeException();
}
}
}
/// <summary>
/// The length of the array.
/// </summary>
public int Length
{
get => m_Length;
}
/// <summary>
/// Returns a string representation of the array's elements.
/// </summary>
/// <returns></returns>
/// <exception cref="IndexOutOfRangeException"></exception>
public override string ToString()
{
switch (m_Length)
{
case 1:
return $"[{m_Elem0}]";
case 2:
return $"[{m_Elem0}, {m_Elem1}]";
case 3:
return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}]";
case 4:
return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}, {m_Elem3}]";
default:
throw new IndexOutOfRangeException();
}
}
/// <summary>
/// Check that the arrays have the same length and have all equal values.
/// </summary>
/// <param name="lhs"></param>
/// <param name="rhs"></param>
/// <returns>Whether the arrays are equivalent.</returns>
public static bool operator ==(InplaceArray<T> lhs, InplaceArray<T> rhs)
{
return lhs.Equals(rhs);
}
/// <summary>
/// Check that the arrays are not equivalent.
/// </summary>
/// <param name="lhs"></param>
/// <param name="rhs"></param>
/// <returns>Whether the arrays are not equivalent</returns>
public static bool operator !=(InplaceArray<T> lhs, InplaceArray<T> rhs) => !lhs.Equals(rhs);
/// <summary>
/// Check that the arrays are equivalent.
/// </summary>
/// <param name="other"></param>
/// <returns>Whether the arrays are not equivalent</returns>
public override bool Equals(object other) => other is InplaceArray<T> other1 && this.Equals(other1);
/// <summary>
/// Check that the arrays are equivalent.
/// </summary>
/// <param name="other"></param>
/// <returns>Whether the arrays are not equivalent</returns>
public bool Equals(InplaceArray<T> other)
{
// See https://montemagno.com/optimizing-c-struct-equality-with-iequatable/
var thisTuple = (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length);
var otherTuple = (other.m_Elem0, other.m_Elem1, other.m_Elem2, other.m_Elem3, other.Length);
return thisTuple.Equals(otherTuple);
}
/// <summary>
/// Get a hashcode for the array.
/// </summary>
/// <returns></returns>
public override int GetHashCode()
{
return (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length).GetHashCode();
}
}
}

3
com.unity.ml-agents/Runtime/InplaceArray.cs.meta


fileFormatVersion: 2
guid: c1a80abee18a41c8aee89aeb33f5985d
timeCreated: 1615506199

140
com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs


using Unity.Barracuda;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// A description of the observations that an ISensor produces.
/// This includes the size of the observation, the properties of each dimension, and how the observation
/// should be used for training.
/// </summary>
public struct ObservationSpec
{
internal readonly InplaceArray<int> m_Shape;
/// <summary>
/// The size of the observations that will be generated.
/// For example, a sensor that observes the velocity of a rigid body (in 3D) would use [3].
/// A sensor that returns an RGB image would use [Height, Width, 3].
/// </summary>
public InplaceArray<int> Shape
{
get => m_Shape;
}
internal readonly InplaceArray<DimensionProperty> m_DimensionProperties;
/// <summary>
/// The properties of each dimensions of the observation.
/// The length of the array must be equal to the rank of the observation tensor.
/// </summary>
/// <remarks>
/// It is generally recommended to use default values provided by helper functions,
/// as not all combinations of DimensionProperty may be supported by the trainer.
/// </remarks>
public InplaceArray<DimensionProperty> DimensionProperties
{
get => m_DimensionProperties;
}
internal ObservationType m_ObservationType;
/// <summary>
/// The type of the observation, e.g. whether they are generic or
/// help determine the goal for the Agent.
/// </summary>
public ObservationType ObservationType
{
get => m_ObservationType;
}
/// <summary>
/// The number of dimensions of the observation.
/// </summary>
public int Rank
{
get { return Shape.Length; }
}
/// <summary>
/// Construct an ObservationSpec for 1-D observations of the requested length.
/// </summary>
/// <param name="length"></param>
/// <param name="obsType"></param>
/// <returns></returns>
public static ObservationSpec Vector(int length, ObservationType obsType = ObservationType.Default)
{
return new ObservationSpec(
new InplaceArray<int>(length),
new InplaceArray<DimensionProperty>(DimensionProperty.None),
obsType
);
}
/// <summary>
/// Construct an ObservationSpec for variable-length observations.
/// </summary>
/// <param name="obsSize"></param>
/// <param name="maxNumObs"></param>
/// <returns></returns>
public static ObservationSpec VariableLength(int obsSize, int maxNumObs)
{
var dimProps = new InplaceArray<DimensionProperty>(
DimensionProperty.VariableSize,
DimensionProperty.None
);
return new ObservationSpec(
new InplaceArray<int>(obsSize, maxNumObs),
dimProps
);
}
/// <summary>
/// Construct an ObservationSpec for visual-like observations, e.g. observations
/// with a height, width, and possible multiple channels.
/// </summary>
/// <param name="height"></param>
/// <param name="width"></param>
/// <param name="channels"></param>
/// <param name="obsType"></param>
/// <returns></returns>
public static ObservationSpec Visual(int height, int width, int channels, ObservationType obsType = ObservationType.Default)
{
var dimProps = new InplaceArray<DimensionProperty>(
DimensionProperty.TranslationalEquivariance,
DimensionProperty.TranslationalEquivariance,
DimensionProperty.None
);
return new ObservationSpec(
new InplaceArray<int>(height, width, channels),
dimProps,
obsType
);
}
/// <summary>
/// Create a general ObservationSpec from the shape, dimension properties, and observation type.
/// </summary>
/// <remarks>
/// Note that not all combinations of DimensionProperty may be supported by the trainer.
/// shape and dimensionProperties must have the same size.
/// </remarks>
/// <param name="shape"></param>
/// <param name="dimensionProperties"></param>
/// <param name="observationType"></param>
/// <exception cref="UnityAgentsException"></exception>
public ObservationSpec(
InplaceArray<int> shape,
InplaceArray<DimensionProperty> dimensionProperties,
ObservationType observationType = ObservationType.Default
)
{
if (shape.Length != dimensionProperties.Length)
{
throw new UnityAgentsException("shape and dimensionProperties must have the same length.");
}
m_Shape = shape;
m_DimensionProperties = dimensionProperties;
m_ObservationType = observationType;
}
}
}

3
com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta


fileFormatVersion: 2
guid: cc1734d60fd5485ead94247cb206aa35
timeCreated: 1615412644

192
com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs


using System;
using System.Collections;
using NUnit.Framework;
using Unity.MLAgents;
using UnityEngine;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class InplaceArrayTests
{
class LengthCases : IEnumerable
{
public IEnumerator GetEnumerator()
{
yield return 1;
yield return 2;
yield return 3;
yield return 4;
}
}
private InplaceArray<int> GetTestArray(int length)
{
switch (length)
{
case 1:
return new InplaceArray<int>(11);
case 2:
return new InplaceArray<int>(11, 22);
case 3:
return new InplaceArray<int>(11, 22, 33);
case 4:
return new InplaceArray<int>(11, 22, 33, 44);
default:
throw new ArgumentException("bad test!");
}
}
private InplaceArray<int> GetZeroArray(int length)
{
switch (length)
{
case 1:
return new InplaceArray<int>(0);
case 2:
return new InplaceArray<int>(0, 0);
case 3:
return new InplaceArray<int>(0, 0, 0);
case 4:
return new InplaceArray<int>(0, 0, 0, 0);
default:
throw new ArgumentException("bad test!");
}
}
[Test]
public void TestInplaceArrayCtor()
{
var a1 = new InplaceArray<int>(11);
Assert.AreEqual(1, a1.Length);
Assert.AreEqual(11, a1[0]);
var a2 = new InplaceArray<int>(11, 22);
Assert.AreEqual(2, a2.Length);
Assert.AreEqual(11, a2[0]);
Assert.AreEqual(22, a2[1]);
var a3 = new InplaceArray<int>(11, 22, 33);
Assert.AreEqual(3, a3.Length);
Assert.AreEqual(11, a3[0]);
Assert.AreEqual(22, a3[1]);
Assert.AreEqual(33, a3[2]);
var a4 = new InplaceArray<int>(11, 22, 33, 44);
Assert.AreEqual(4, a4.Length);
Assert.AreEqual(11, a4[0]);
Assert.AreEqual(22, a4[1]);
Assert.AreEqual(33, a4[2]);
Assert.AreEqual(44, a4[3]);
}
[TestCaseSource(typeof(LengthCases))]
public void TestInplaceGetSet(int length)
{
var original = GetTestArray(length);
for (var i = 0; i < original.Length; i++)
{
var modified = original;
modified[i] = 0;
for (var j = 0; j < original.Length; j++)
{
if (i == j)
{
// This is the one we overwrote
Assert.AreEqual(0, modified[j]);
}
else
{
// Other elements should be unchanged
Assert.AreEqual(original[j], modified[j]);
}
}
}
}
[TestCaseSource(typeof(LengthCases))]
public void TestInvalidAccess(int length)
{
var tmp = 0;
var a = GetTestArray(length);
// get
Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[-1]; });
Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[length]; });
// set
Assert.Throws<IndexOutOfRangeException>(() => { a[-1] = 0; });
Assert.Throws<IndexOutOfRangeException>(() => { a[length] = 0; });
// Make sure temp is used
Assert.AreEqual(0, tmp);
}
[Test]
public void TestOperatorEqualsDifferentLengths()
{
// Check arrays of different length are never equal (even if they have 0s in all elements)
for (var l1 = 1; l1 <= 4; l1++)
{
var a1 = GetZeroArray(l1);
for (var l2 = 1; l2 <= 4; l2++)
{
var a2 = GetZeroArray(l2);
if (l1 == l2)
{
Assert.AreEqual(a1, a2);
Assert.IsTrue(a1 == a2);
}
else
{
Assert.AreNotEqual(a1, a2);
Assert.IsTrue(a1 != a2);
}
}
}
}
[TestCaseSource(typeof(LengthCases))]
public void TestOperatorEquals(int length)
{
for (var index = 0; index < length; index++)
{
var a1 = GetTestArray(length);
var a2 = GetTestArray(length);
Assert.AreEqual(a1, a2);
Assert.IsTrue(a1 == a2);
a1[index] = 42;
Assert.AreNotEqual(a1, a2);
Assert.IsTrue(a1 != a2);
a2[index] = 42;
Assert.AreEqual(a1, a2);
Assert.IsTrue(a1 == a2);
}
}
[Test]
public void TestToString()
{
Assert.AreEqual("[1]", new InplaceArray<int>(1).ToString());
Assert.AreEqual("[1, 2]", new InplaceArray<int>(1, 2).ToString());
Assert.AreEqual("[1, 2, 3]", new InplaceArray<int>(1, 2, 3).ToString());
Assert.AreEqual("[1, 2, 3, 4]", new InplaceArray<int>(1, 2, 3, 4).ToString());
}
[TestCaseSource(typeof(LengthCases))]
public void TestFromList(int length)
{
var intArray = new int[length];
for (var i = 0; i < length; i++)
{
intArray[i] = (i + 1) * 11; // 11, 22, etc.
}
var converted = InplaceArray<int>.FromList(intArray);
Assert.AreEqual(GetTestArray(length), converted);
}
}
}

78
com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs


using NUnit.Framework;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class ObservationSpecTests
{
[Test]
public void TestVectorObsSpec()
{
var obsSpec = ObservationSpec.Vector(5);
Assert.AreEqual(1, obsSpec.Rank);
var shape = obsSpec.Shape;
Assert.AreEqual(1, shape.Length);
Assert.AreEqual(5, shape[0]);
var dimensionProps = obsSpec.DimensionProperties;
Assert.AreEqual(1, dimensionProps.Length);
Assert.AreEqual(DimensionProperty.None, dimensionProps[0]);
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
}
[Test]
public void TestVariableLengthObsSpec()
{
var obsSpec = ObservationSpec.VariableLength(5, 6);
Assert.AreEqual(2, obsSpec.Rank);
var shape = obsSpec.Shape;
Assert.AreEqual(2, shape.Length);
Assert.AreEqual(5, shape[0]);
Assert.AreEqual(6, shape[1]);
var dimensionProps = obsSpec.DimensionProperties;
Assert.AreEqual(2, dimensionProps.Length);
Assert.AreEqual(DimensionProperty.VariableSize, dimensionProps[0]);
Assert.AreEqual(DimensionProperty.None, dimensionProps[1]);
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
}
[Test]
public void TestVisualObsSpec()
{
var obsSpec = ObservationSpec.Visual(5, 6, 7);
Assert.AreEqual(3, obsSpec.Rank);
var shape = obsSpec.Shape;
Assert.AreEqual(3, shape.Length);
Assert.AreEqual(5, shape[0]);
Assert.AreEqual(6, shape[1]);
Assert.AreEqual(7, shape[2]);
var dimensionProps = obsSpec.DimensionProperties;
Assert.AreEqual(3, dimensionProps.Length);
Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[0]);
Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[1]);
Assert.AreEqual(DimensionProperty.None, dimensionProps[2]);
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
}
[Test]
public void TestMismatchShapeDimensionPropThrows()
{
var shape = new InplaceArray<int>(1, 2);
var dimProps = new InplaceArray<DimensionProperty>(DimensionProperty.TranslationalEquivariance);
Assert.Throws<UnityAgentsException>(() =>
{
new ObservationSpec(shape, dimProps);
});
}
}
}

3
com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta


fileFormatVersion: 2
guid: 27ff1979bd5e4b8ebeb4d98f414a5090
timeCreated: 1615863866

31
com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs


namespace Unity.MLAgents.Sensors
{
/// <summary>
/// The ObservationType enum of the Sensor.
/// </summary>
internal enum ObservationType
{
// Collected observations are generic.
Default = 0,
// Collected observations contain goal information.
Goal = 1,
// Collected observations contain reward information.
Reward = 2,
// Collected observations are messages from other agents.
Message = 3,
}
/// <summary>
/// Sensor interface for sensors with variable types.
/// </summary>
internal interface ITypedSensor
{
/// <summary>
/// Returns the ObservationType enum corresponding to the type of the sensor.
/// </summary>
/// <returns>The ObservationType enum</returns>
ObservationType GetObservationType();
}
}

11
com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta


fileFormatVersion: 2
guid: 3751edac8122c411dbaef8f1b7043b82
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

47
com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs


namespace Unity.MLAgents.Sensors
{
/// <summary>
/// The Dimension property flags of the observations
/// </summary>
[System.Flags]
public enum DimensionProperty
{
/// <summary>
/// No properties specified.
/// </summary>
Unspecified = 0,
/// <summary>
/// No Property of the observation in that dimension. Observation can be processed with
/// fully connected networks.
/// </summary>
None = 1,
/// <summary>
/// Means it is suitable to do a convolution in this dimension.
/// </summary>
TranslationalEquivariance = 2,
/// <summary>
/// Means that there can be a variable number of observations in this dimension.
/// The observations are unordered.
/// </summary>
VariableSize = 4,
}
/// <summary>
/// Sensor interface for sensors with special dimension properties.
/// </summary>
internal interface IDimensionPropertiesSensor
{
/// <summary>
/// Returns the array containing the properties of each dimensions of the
/// observation. The length of the array must be equal to the rank of the
/// observation tensor.
/// </summary>
/// <returns>The array of DimensionProperty</returns>
DimensionProperty[] GetDimensionProperties();
}
}

/com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs.meta → /com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta

正在加载...
取消
保存