浏览代码

[MLA-1634] Add ObservationSpec and update ISensor interfaces (#5127)

/v2-staging-rebase
Christopher Goy 4 年前
当前提交
4d7ce41b
共有 55 个文件被更改,包括 1139 次插入364 次删除
  1. 2
      DevProject/Packages/manifest.json
  2. 8
      DevProject/Packages/packages-lock.json
  3. 18
      DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json
  4. 4
      DevProject/ProjectSettings/ProjectVersion.txt
  5. 4
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
  6. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
  7. 8
      Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
  8. 12
      com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
  9. 30
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  10. 10
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  11. 12
      com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
  12. 8
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs
  13. 6
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs
  14. 3
      com.unity.ml-agents/CHANGELOG.md
  15. 1
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
  16. 7
      com.unity.ml-agents/Runtime/Analytics/Events.cs
  17. 53
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  18. 95
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  19. 53
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  20. 2
      com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
  21. 4
      com.unity.ml-agents/Runtime/SensorHelper.cs
  22. 19
      com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
  23. 32
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  24. 70
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  25. 13
      com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
  26. 8
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
  27. 12
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  28. 8
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
  29. 21
      com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
  30. 48
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  31. 10
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  32. 23
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  33. 4
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  34. 19
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  35. 6
      com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs
  36. 3
      com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs
  37. 13
      com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
  38. 2
      com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs
  39. 20
      com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs
  40. 2
      com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs
  41. 27
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  42. 30
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
  43. 20
      docs/Migrating.md
  44. 2
      com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
  45. 241
      com.unity.ml-agents/Runtime/InplaceArray.cs
  46. 3
      com.unity.ml-agents/Runtime/InplaceArray.cs.meta
  47. 140
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
  48. 3
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta
  49. 192
      com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs
  50. 78
      com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs
  51. 3
      com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta
  52. 31
      com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs
  53. 11
      com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta
  54. 47
      com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs
  55. 0
      /com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta

2
DevProject/Packages/manifest.json


"com.unity.package-manager-doctools": "1.7.0-preview",
"com.unity.package-validation-suite": "0.19.0-preview",
"com.unity.purchasing": "2.2.1",
"com.unity.test-framework": "1.1.20",
"com.unity.test-framework": "1.1.22",
"com.unity.test-framework.performance": "2.2.0-preview",
"com.unity.testtools.codecoverage": "1.0.0-pre.3",
"com.unity.textmeshpro": "2.0.1",

8
DevProject/Packages/packages-lock.json


"url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates"
},
"com.unity.barracuda": {
"version": "1.3.0-preview",
"version": "1.3.1-preview",
"depth": 1,
"source": "registry",
"dependencies": {

"depth": 0,
"source": "local",
"dependencies": {
"com.unity.barracuda": "1.3.0-preview",
"com.unity.barracuda": "1.3.1-preview",
"com.unity.modules.imageconversion": "1.0.0",
"com.unity.modules.jsonserialize": "1.0.0",
"com.unity.modules.physics": "1.0.0",

"depth": 0,
"source": "local",
"dependencies": {
"com.unity.ml-agents": "1.7.2-preview"
"com.unity.ml-agents": "1.8.0-preview"
}
},
"com.unity.multiplayer-hlapi": {

"url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates"
},
"com.unity.test-framework": {
"version": "1.1.20",
"version": "1.1.22",
"depth": 0,
"source": "registry",
"dependencies": {

18
DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json


"m_Name": "Settings",
"m_Path": "ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json",
"m_Dictionary": {
"m_DictionaryValues": []
"m_DictionaryValues": [
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "Path",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "HistoryPath",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "IncludeAssemblies",
"value": "{\"m_Value\":\"Assembly-CSharp,Runtime,Unity.ML-Agents,Unity.ML-Agents.Extensions\"}"
}
]
}
}

4
DevProject/ProjectSettings/ProjectVersion.txt


m_EditorVersion: 2019.4.19f1
m_EditorVersionWithRevision: 2019.4.19f1 (ca5b14067cec)
m_EditorVersion: 2019.4.20f1
m_EditorVersionWithRevision: 2019.4.20f1 (6dd1c08eedfa)

4
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs


}
/// <inheritdoc/>
public override int[] GetObservationShape()
public override ObservationSpec GetObservationSpec()
return new[] { BasicController.k_Extents };
return ObservationSpec.Vector(BasicController.k_Extents);
}
/// <inheritdoc/>

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs


public abstract void WriteObservation(float[] output);
/// <inheritdoc/>
public abstract int[] GetObservationShape();
public abstract ObservationSpec GetObservationSpec();
/// <inheritdoc/>
public abstract string GetName();

8
Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs


{
Texture2D m_Texture;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
/// <summary>

var width = texture.width;
var height = texture.height;
m_Name = name;
m_Shape = new[] { height, width, 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, 3);
m_CompressionType = compressionType;
}

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs


{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
private int[] m_SparseChannelMapping;
private string m_Name;

m_NumSpecialTypes = board.NumSpecialTypes;
m_ObservationType = obsType;
m_Shape = obsType == Match3ObservationType.Vector ?
new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
m_ObservationSpec = obsType == Match3ObservationType.Vector
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

30
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


protected bool Initialized = false;
/// <summary>
/// Array holding the dimensions of the resulting tensor
/// Cached ObservationSpec
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
//
// Debug Parameters

// Default root reference to current game object
if (rootReference == null)
rootReference = gameObject;
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
compressedImgs = new List<byte[]>();
byteSizesBytesList = new List<byte[]>();

CellActivity[i] = DebugDefaultColor;
}
}
}
/// <summary>Gets the shape of the grid observation</summary>
/// <returns>integer array shape of the grid observation</returns>
public int[] GetFloatObservationShape()
{
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
}
/// <inheritdoc/>

/// <summary>Gets the observation shape</summary>
/// <returns>int[] of the observation shape</returns>
public ObservationSpec GetObservationSpec()
{
// Lazy update
var shape = m_ObservationSpec.Shape;
if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell)
{
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
}
return m_ObservationSpec;
}
/// <inheritdoc/>
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
var shape = m_ObservationSpec.Shape;
return new int[] { shape[0], shape[1], shape[2] };
}
/// <inheritdoc/>

10
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


/// </summary>
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_SensorName;
PoseExtractor m_PoseExtractor;

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}
#if UNITY_2020_1_OR_NEWER

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs


var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

8
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

6
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 7 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
}
}

3
com.unity.ml-agents/CHANGELOG.md


## [Unreleased]
### Major Changes
#### com.unity.ml-agents (C#)
- Several breaking interface changes were made. See the
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details.
- Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart.
- The interface for disabling discrete actions in `IDiscreteActionMask` has changed.
`WriteMask(int branch, IEnumerable<int> actionIndices)` was replaced with

1
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs


/// <summary>
/// Set whether or not the action index for the given branch is allowed.
/// </summary>
/// <remarks>
/// By default, all discrete actions are allowed.
/// If isEnabled is false, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndex correspond

7
com.unity.ml-agents/Runtime/Analytics/Events.cs


public static EventObservationSpec FromSensor(ISensor sensor)
{
var shape = sensor.GetObservationShape();
var dimProps = (sensor as IDimensionPropertiesSensor)?.GetDimensionProperties();
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
var dimProps = obsSpec.DimensionProperties;
dimInfos[i].Flags = dimProps != null ? (int)dimProps[i] : 0;
dimInfos[i].Flags = (int)dimProps[i];
}
var builtInSensorType =

53
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


/// <returns></returns>
public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
{
var shape = sensor.GetObservationShape();
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
ObservationProto observationProto = null;
var compressionType = sensor.GetCompressionType();
// Check capabilities if we need to concatenate PNGs

floatDataProto.Data.Add(0.0f);
}
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0);
sensor.Write(observationWriter);
observationProto = new ObservationProto

observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
}
}
// Add the dimension properties if any to the observationProto
var dimensionPropertySensor = sensor as IDimensionPropertiesSensor;
if (dimensionPropertySensor != null)
// Add the dimension properties to the observationProto
var dimensionProperties = obsSpec.DimensionProperties;
for (int i = 0; i < dimensionProperties.Length; i++)
var dimensionProperties = dimensionPropertySensor.GetDimensionProperties();
int[] intDimensionProperties = new int[dimensionProperties.Length];
for (int i = 0; i < dimensionProperties.Length; i++)
observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
}
// Checking trainer compatibility with variable length observations
if (dimensionProperties == new InplaceArray<DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None))
{
var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
if (!trainerCanHandleVarLenObs)
observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
}
// Checking trainer compatibility with variable length observations
if (dimensionProperties.Length == 2)
{
if (dimensionProperties[0] == DimensionProperty.VariableSize &&
dimensionProperties[1] == DimensionProperty.None)
{
var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
if (!trainerCanHandleVarLenObs)
{
throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
}
}
throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
observationProto.Shape.AddRange(shape);
// Add the observation type, if any, to the observationProto
var typeSensor = sensor as ITypedSensor;
if (typeSensor != null)
for (var i = 0; i < shape.Length; i++)
observationProto.ObservationType = (ObservationTypeProto)typeSensor.GetObservationType();
observationProto.Shape.Add(shape[i]);
else
var sensorName = sensor.GetName();
if (!string.IsNullOrEmpty(sensorName))
observationProto.ObservationType = ObservationTypeProto.Default;
observationProto.Name = sensorName;
observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType;
return observationProto;
}

95
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sensor = sensors[sensorIndex];
if (sensor.GetObservationShape().Length == 3)
if (sensor.GetObservationSpec().Shape.Length == 3)
{
if (!tensorsNames.Contains(
TensorNames.GetVisualObservationName(visObsIndex)))

}
visObsIndex++;
}
if (sensor.GetObservationShape().Length == 2)
if (sensor.GetObservationSpec().Shape.Length == 2)
{
if (!tensorsNames.Contains(
TensorNames.GetObservationName(sensorIndex)))

static string CheckVisualObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var heightBp = shape[0];
var widthBp = shape[1];
var pixelBp = shape[2];

static string CheckRankTwoObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var dim1Bp = shape[0];
var dim2Bp = shape[1];
var dim1T = tensorProxy.Channels;

for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationShape().Length == 3)
if (sens.GetObservationSpec().Shape.Length == 3)
{
tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] =

if (sens.GetObservationShape().Length == 2)
if (sens.GetObservationSpec().Shape.Length == 2)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);

var totalVectorSensorSize = 0;
foreach (var sens in sensors)
{
if ((sens.GetObservationShape().Length == 1))
if ((sens.GetObservationSpec().Shape.Length == 1))
totalVectorSensorSize += sens.GetObservationShape()[0];
totalVectorSensorSize += sens.GetObservationSpec().Shape[0];
}
}

foreach (var sensorComp in sensors)
{
if (sensorComp.GetObservationShape().Length == 1)
if (sensorComp.GetObservationSpec().Shape.Length == 1)
var vecSize = sensorComp.GetObservationShape()[0];
var vecSize = sensorComp.GetObservationSpec().Shape[0];
if (sensorSizes.Length == 0)
{
sensorSizes = $"[{vecSize}";

return null;
}
/// <summary>
/// Generates failed checks that correspond to inputs shapes incompatibilities between
/// the model and the BrainParameters.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensors">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>A IEnumerable of the checks that failed</returns>
static IEnumerable<FailedCheck> CheckInputTensorShape(
Model model, BrainParameters brainParameters, ISensor[] sensors,
int observableAttributeTotalSize)
{
var failedModelChecks = new List<FailedCheck>();
var tensorTester =
new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, FailedCheck>>()
{
{TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)},
};
foreach (var mem in model.memories)
{
tensorTester[mem.input] = ((bp, tensor, scs, i) => null);
}
for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationSpec().Rank == 3)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens);
}
if (sens.GetObservationSpec().Rank == 2)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);
}
if (sens.GetObservationSpec().Rank == 1)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankOneObsShape(tensor, sens);
}
}
// If the model expects an input but it is not in this list
foreach (var tensor in model.GetInputTensors())
{
if (!tensorTester.ContainsKey(tensor.name))
{
failedModelChecks.Add(FailedCheck.Warning("Model contains an unexpected input named : " + tensor.name
));
}
else
{
var tester = tensorTester[tensor.name];
var error = tester.Invoke(brainParameters, tensor, sensors, observableAttributeTotalSize);
if (error != null)
{
failedModelChecks.Add(error);
}
}
}
return failedModelChecks;
}
/// <summary>
/// Checks that the shape of the Previous Vector Action input placeholder is the same in the
/// model and in the Brain Parameters.

53
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
{
// Loop through the sensors on a representative agent.
// All vector observations use a shared ObservationGenerator since they are concatenated.
// All other observations use a unique ObservationInputGenerator
var visIndex = 0;
ObservationGenerator vecObsGen = null;
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
{
// Loop through the sensors on a representative agent.
// All vector observations use a shared ObservationGenerator since they are concatenated.
// All other observations use a unique ObservationInputGenerator
var visIndex = 0;
ObservationGenerator vecObsGen = null;
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var sensor = sensors[sensorIndex];
var rank = sensor.GetObservationSpec().Rank;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
{
case 1:
if (vecObsGen == null)
{
vecObsGen = new ObservationGenerator(allocator);
}
obsGen = vecObsGen;
obsGenName = TensorNames.VectorObservationPlaceholder;
break;
case 2:
// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetObservationName(sensorIndex);
break;
case 3:
// If the tensor is of rank 3, we use the "visual observation
// index", which only counts the rank 3 sensors
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetVisualObservationName(visIndex);
visIndex++;
break;
default:
throw new UnityAgentsException(
$"Sensor {sensor.GetName()} have an invalid rank {rank}");
}
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}
if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var rank = shape.Length;
ObservationGenerator obsGen = null;
string obsGenName = null;

2
com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs


{
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationShape(), 0);
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationSpec(), 0);
sensor.Write(m_ObservationWriter);
}
else

4
com.unity.ml-agents/Runtime/SensorHelper.cs


}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

19
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs


/// <summary>
/// A Sensor that allows to observe a variable number of entities.
/// </summary>
public class BufferSensor : ISensor, IDimensionPropertiesSensor, IBuiltInSensor
public class BufferSensor : ISensor, IBuiltInSensor
{
private string m_Name;
private int m_MaxNumObs;

static DimensionProperty[] s_DimensionProperties = new DimensionProperty[]{
DimensionProperty.VariableSize,
DimensionProperty.None
};
ObservationSpec m_ObservationSpec;
public BufferSensor(int maxNumberObs, int obsSize, string name)
{
m_Name = name;

m_CurrentNumObservables = 0;
m_ObservationSpec = ObservationSpec.VariableLength(m_MaxNumObs, m_ObsSize);
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new int[] { m_MaxNumObs, m_ObsSize };
}
/// <inheritdoc/>
public DimensionProperty[] GetDimensionProperties()
{
return s_DimensionProperties;
return m_ObservationSpec;
}
/// <summary>

32
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


/// <summary>
/// A sensor that wraps a Camera object to generate visual observations for an agent.
/// </summary>
public class CameraSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor
public class CameraSensor : ISensor, IBuiltInSensor
{
Camera m_Camera;
int m_Width;

int[] m_Shape;
private ObservationSpec m_ObservationSpec;
static DimensionProperty[] s_DimensionProperties = new DimensionProperty[] {
DimensionProperty.TranslationalEquivariance,
DimensionProperty.TranslationalEquivariance,
DimensionProperty.None };
/// <summary>
/// The Camera used for rendering the sensor observations.

m_Height = height;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = GenerateShape(width, height, grayscale);
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
m_CompressionType = compression;
}

return m_Name;
}
/// <summary>
/// Accessor for the size of the sensor data. Will be h x w x 1 for grayscale and
/// h x w x 3 for color.
/// </summary>
/// <returns>Size of each of the three dimensions.</returns>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <summary>
/// Accessor for the dimension properties of a camera sensor. A camera sensor
/// Has translational equivariance along width and hight and no property along
/// the channels dimension.
/// </summary>
/// <returns></returns>
public DimensionProperty[] GetDimensionProperties()
/// <inheritdoc/>
public ObservationSpec GetObservationSpec()
return s_DimensionProperties;
return m_ObservationSpec;
}
/// <summary>

70
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


}
/// <summary>
/// The Dimension property flags of the observations
/// </summary>
[System.Flags]
public enum DimensionProperty
{
/// <summary>
/// No properties specified.
/// </summary>
Unspecified = 0,
/// <summary>
/// No Property of the observation in that dimension. Observation can be processed with
/// fully connected networks.
/// </summary>
None = 1,
/// <summary>
/// Means it is suitable to do a convolution in this dimension.
/// </summary>
TranslationalEquivariance = 2,
/// <summary>
/// Means that there can be a variable number of observations in this dimension.
/// The observations are unordered.
/// </summary>
VariableSize = 4,
}
/// <summary>
/// The ObservationType enum of the Sensor.
/// </summary>
public enum ObservationType
{
/// <summary>
/// Collected observations are generic.
/// </summary>
Default = 0,
/// <summary>
/// Collected observations contain goal information.
/// </summary>
Goal = 1,
/// <summary>
/// Collected observations contain reward information.
/// </summary>
Reward = 2,
/// <summary>
/// Collected observations are messages from other agents.
/// </summary>
Message = 3,
}
/// <summary>
/// Returns the size of the observations that will be generated.
/// For example, a sensor that observes the velocity of a rigid body (in 3D) would return
/// new {3}. A sensor that returns an RGB image would return new [] {Height, Width, 3}
/// Returns a description of the observations that will be generated by the sensor.
/// See <see cref="ObservationSpec"/> for more details, and helper methods to create one.
/// <returns>Size of the observations that will be generated.</returns>
int[] GetObservationShape();
/// <returns></returns>
ObservationSpec GetObservationSpec();
/// <summary>
/// Write the observation data directly to the <see cref="ObservationWriter"/>.

/// <returns></returns>
public static int ObservationSize(this ISensor sensor)
{
var shape = sensor.GetObservationShape();
var obsSpec = sensor.GetObservationSpec();
foreach (var dim in shape)
for (var i = 0; i < obsSpec.Rank; i++)
count *= dim;
count *= obsSpec.Shape[i];
}
return count;

13
com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs


/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
/// <param name="observationSpec">ObservationSpec of the observation to be written</param>
/// <param name="offset">Offset from the start of the float data to write to.</param>
internal void SetTarget(IList<float> data, ObservationSpec observationSpec, int offset)
{
SetTarget(data, observationSpec.Shape, offset);
}
/// <summary>
/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
internal void SetTarget(IList<float> data, int[] shape, int offset)
internal void SetTarget(IList<float> data, InplaceArray<int> shape, int offset)
{
m_Data = data;
m_Offset = offset;

8
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


public class RayPerceptionSensor : ISensor, IBuiltInSensor
{
float[] m_Observations;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_Name;
RayPerceptionInput m_RayPerceptionInput;

void SetNumObservations(int numObservations)
{
m_Shape = new[] { numObservations };
m_ObservationSpec = ObservationSpec.Vector(numObservations);
m_Observations = new float[numObservations];
}

public void Reset() { }
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


// Cached sensor names and shapes.
string m_SensorName;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
int m_NumFloats;
public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size)
{

m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute;
m_SensorName = reflectionSensorInfo.SensorName;
m_Shape = new[] { size };
m_ObservationSpec = ObservationSpec.Vector(size);
m_NumFloats = size;
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

return m_Shape[0];
return m_NumFloats;
}
internal abstract void WriteReflectedField(ObservationWriter writer);

8
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs


RenderTexture m_RenderTexture;
bool m_Grayscale;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
/// <summary>

var height = renderTexture != null ? renderTexture.height : 0;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = new[] { height, width, grayscale ? 1 : 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, grayscale ? 1 : 3);
m_CompressionType = compressionType;
}

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

21
com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs


{
internal class SensorShapeValidator
{
List<int[]> m_SensorShapes;
List<ObservationSpec> m_SensorShapes;
/// <summary>
/// Check that the List Sensors are the same shape as the previous ones.

{
if (m_SensorShapes == null)
{
m_SensorShapes = new List<int[]>(sensors.Count);
m_SensorShapes = new List<ObservationSpec>(sensors.Count);
m_SensorShapes.Add(sensor.GetObservationShape());
m_SensorShapes.Add(sensor.GetObservationSpec());
}
}
else

);
for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++)
{
var cachedShape = m_SensorShapes[i];
var sensorShape = sensors[i].GetObservationShape();
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
for (var j = 0; j < Mathf.Min(cachedShape.Length, sensorShape.Length); j++)
{
Debug.Assert(cachedShape[j] == sensorShape[j], "Sensor sizes must match.");
}
var cachedSpec = m_SensorShapes[i];
var sensorSpec = sensors[i].GetObservationSpec();
Debug.AssertFormat(
cachedSpec.Shape == sensorSpec.Shape,
"Sensor shapes must match. {0} != {1}",
cachedSpec.Shape,
sensorSpec.Shape
);
}
}
}

48
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


int m_UnstackedObservationSize;
string m_Name;
int[] m_Shape;
int[] m_WrappedShape;
private ObservationSpec m_ObservationSpec;
private ObservationSpec m_WrappedSpec;
/// <summary>
/// Buffer of previous observations

m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";
m_WrappedShape = wrapped.GetObservationShape();
m_Shape = new int[m_WrappedShape.Length];
m_WrappedSpec = wrapped.GetObservationSpec();
for (int d = 0; d < m_WrappedShape.Length; d++)
{
m_Shape[d] = m_WrappedShape[d];
}
// Set up the cached observation spec for the StackingSensor
var newShape = m_WrappedSpec.Shape;
m_Shape[m_Shape.Length - 1] *= numStackedObservations;
newShape[newShape.Length - 1] *= numStackedObservations;
m_ObservationSpec = new ObservationSpec(
newShape, m_WrappedSpec.DimensionProperties, m_WrappedSpec.ObservationType
);
// Initialize uncompressed buffer anyway in case python trainer does not
// support the compression mapping and has to fall back to uncompressed obs.

m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
}
if (m_Shape.Length != 1)
if (m_WrappedSpec.Rank != 1)
m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]);
var wrappedShape = m_WrappedSpec.Shape;
m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]);
}
}

// First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one.
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedShape, 0);
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec, 0);
if (m_WrappedShape.Length == 1)
if (m_WrappedSpec.Rank == 1)
{
for (var i = 0; i < m_NumStackedObservations; i++)
{

for (var i = 0; i < m_NumStackedObservations; i++)
{
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
for (var h = 0; h < m_WrappedShape[0]; h++)
for (var h = 0; h < m_WrappedSpec.Shape[0]; h++)
for (var w = 0; w < m_WrappedShape[1]; w++)
for (var w = 0; w < m_WrappedSpec.Shape[1]; w++)
for (var c = 0; c < m_WrappedShape[2]; c++)
for (var c = 0; c < m_WrappedSpec.Shape[2]; c++)
writer[h, w, i * m_WrappedShape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
writer[h, w, i * m_WrappedSpec.Shape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
numWritten = m_WrappedShape[0] * m_WrappedShape[1] * m_WrappedShape[2] * m_NumStackedObservations;
numWritten = m_WrappedSpec.Shape[0] * m_WrappedSpec.Shape[1] * m_WrappedSpec.Shape[2] * m_NumStackedObservations;
}
return numWritten;

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

/// </summary>
internal byte[] CreateEmptyPNG()
{
int height = m_WrappedSensor.GetObservationShape()[0];
int width = m_WrappedSensor.GetObservationShape()[1];
var shape = m_WrappedSpec.Shape;
int height = shape[0];
int width = shape[1];
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
Color32[] resetColorArray = texture2D.GetPixels32();
Color32 black = new Color32(0, 0, 0, 0);

// wrapped sensor doesn't have one, use default mapping.
// Default mapping: {0, 0, 0} for grayscale, identity mapping {1, 2, ..., n} otherwise.
int[] wrappedMapping = null;
int wrappedNumChannel = wrappedSenesor.GetObservationShape()[2];
int wrappedNumChannel = m_WrappedSpec.Shape[2];
var sparseChannelSensor = m_WrappedSensor as ISparseChannelSensor;
if (sparseChannelSensor != null)
{

10
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


// TODO use float[] instead
// TODO allow setting float[]
List<float> m_Observations;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_Name;
/// <summary>

m_Observations = new List<float>(observationSize);
m_Name = name;
m_Shape = new[] { observationSize };
m_ObservationSpec = ObservationSpec.Vector(observationSize);
var expectedObservations = m_Shape[0];
var expectedObservations = m_ObservationSpec.Shape[0];
if (m_Observations.Count > expectedObservations)
{
// Too many observations, truncate

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

23
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


using System;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Analytics;

class DummySensor : ISensor
{
public int[] Shape;
public ObservationSpec ObservationSpec;
public SensorCompressionType CompressionType;
internal DummySensor()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return Shape;
return ObservationSpec;
}
public int Write(ObservationWriter writer)

foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants)
{
var inplaceShape = InplaceArray<int>.FromList(shape);
dummySensor.Shape = shape;
if (shape.Length == 1)
{
dummySensor.ObservationSpec = ObservationSpec.Vector(shape[0]);
}
else if (shape.Length == 3)
{
dummySensor.ObservationSpec = ObservationSpec.Visual(shape[0], shape[1], shape[2]);
}
else
{
throw new ArgumentOutOfRangeException();
}
obsWriter.SetTarget(new float[128], shape, 0);
obsWriter.SetTarget(new float[128], inplaceShape, 0);
var caps = new UnityRLCapabilities
{

4
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs