浏览代码

Merge pull request #5157 from Unity-Technologies/v2-staging-rebase-2

V2 Staging back to main
/develop/lex-walker-model
GitHub 4 年前
当前提交
78851829
共有 143 个文件被更改,包括 2103 次插入836 次删除
  1. 3
      .gitignore
  2. 2
      DevProject/Assets/ML-Agents/Scripts/Tests/Editor/Editor.asmdef
  3. 1
      DevProject/Assets/ML-Agents/Scripts/Tests/Runtime/AcademyTest/AcademyStepperTest.cs
  4. 2
      DevProject/Assets/ML-Agents/Scripts/Tests/Runtime/Runtime.asmdef
  5. 2
      DevProject/Packages/manifest.json
  6. 2
      DevProject/Packages/packages-lock.json
  7. 18
      DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json
  8. 4
      DevProject/ProjectSettings/ProjectVersion.txt
  9. 6
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs
  10. 4
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
  11. 8
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  12. 6
      Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs
  13. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
  14. 5
      Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta
  15. 8
      Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
  16. 2
      com.unity.ml-agents.extensions/Runtime/Input/InputActionActuator.cs
  17. 5
      com.unity.ml-agents.extensions/Runtime/Input/InputActuatorComponent.cs
  18. 69
      com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
  19. 6
      com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs
  20. 12
      com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
  21. 30
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  22. 10
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  23. 12
      com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
  24. 8
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs
  25. 6
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs
  26. 19
      com.unity.ml-agents/CHANGELOG.md
  27. 16
      com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
  28. 2
      com.unity.ml-agents/Runtime/Academy.cs
  29. 14
      com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs
  30. 29
      com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
  31. 4
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
  32. 43
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
  33. 2
      com.unity.ml-agents/Runtime/Actuators/IActuator.cs
  34. 24
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
  35. 2
      com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
  36. 83
      com.unity.ml-agents/Runtime/Agent.cs
  37. 7
      com.unity.ml-agents/Runtime/Analytics/Events.cs
  38. 55
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  39. 44
      com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
  40. 70
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  41. 472
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  42. 10
      com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
  43. 93
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  44. 12
      com.unity.ml-agents/Runtime/Policies/BrainParameters.cs
  45. 2
      com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
  46. 4
      com.unity.ml-agents/Runtime/SensorHelper.cs
  47. 19
      com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
  48. 32
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  49. 70
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  50. 41
      com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
  51. 8
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
  52. 12
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  53. 8
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
  54. 21
      com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs
  55. 21
      com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
  56. 48
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  57. 23
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  58. 25
      com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs
  59. 45
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs
  60. 6
      com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
  61. 5
      com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
  62. 2
      com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs
  63. 2
      com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs
  64. 23
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  65. 10
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  66. 6
      com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs
  67. 3
      com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs
  68. 13
      com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
  69. 2
      com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs
  70. 20
      com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs
  71. 2
      com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs
  72. 27
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  73. 30
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
  74. 26
      docs/Learning-Environment-Design-Agents.md
  75. 50
      docs/Migrating.md
  76. 1
      docs/Readme.md
  77. 2
      ml-agents/mlagents/trainers/torch/distributions.py
  78. 61
      ml-agents/mlagents/trainers/torch/model_serialization.py
  79. 44
      ml-agents/mlagents/trainers/torch/networks.py
  80. 2
      com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
  81. 41
      com.unity.ml-agents/Tests/Editor/Inference/DiscreteActionOutputApplierTest.cs
  82. 77
      com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorApplier.cs
  83. 12
      com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs
  84. 7
      DevProject/Assets/ML-Agents/Scripts/Tests/Editor/Editor.asmdef.meta
  85. 8
      DevProject/Assets/ML-Agents/Scripts/Tests/Editor/MLAgentsSettings.meta
  86. 75
      com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs
  87. 11
      com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs.meta
  88. 198
      com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs
  89. 11
      com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs.meta
  90. 241
      com.unity.ml-agents/Runtime/InplaceArray.cs
  91. 3
      com.unity.ml-agents/Runtime/InplaceArray.cs.meta
  92. 41
      com.unity.ml-agents/Runtime/MLAgentsSettings.cs
  93. 11
      com.unity.ml-agents/Runtime/MLAgentsSettings.cs.meta
  94. 91
      com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs
  95. 11
      com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs.meta
  96. 140
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
  97. 3
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta
  98. 8
      com.unity.ml-agents/Tests/Editor/Inference.meta

3
.gitignore


# Environemnt logfile
*Project.log
# Custom settings asset
*.settings.asset*
# Visual Studio 2015 cache directory
/Project/.vs/

2
DevProject/Assets/ML-Agents/Scripts/Tests/Editor/Editor.asmdef


{
"name": "Unity.ML-Agents.Performance.Tests",
"name": "Unity.ML-Agents.DevTests.Editor",
"references": [
"Unity.ML-Agents.Editor",
"Unity.ML-Agents",

1
DevProject/Assets/ML-Agents/Scripts/Tests/Runtime/AcademyTest/AcademyStepperTest.cs


[SetUp]
public void Setup()
{
Academy.Instance.Dispose();
SceneManager.LoadScene("ML-Agents/Scripts/Tests/Runtime/AcademyTest/AcademyStepperTestScene");
}

2
DevProject/Assets/ML-Agents/Scripts/Tests/Runtime/Runtime.asmdef


{
"name": "Runtime",
"name": "Unity.ML-Agents.DevTests.Runtime",
"references": [
"Unity.ML-Agents"
],

2
DevProject/Packages/manifest.json


"com.unity.package-manager-doctools": "1.7.0-preview",
"com.unity.package-validation-suite": "0.19.0-preview",
"com.unity.purchasing": "2.2.1",
"com.unity.test-framework": "1.1.20",
"com.unity.test-framework": "1.1.22",
"com.unity.test-framework.performance": "2.2.0-preview",
"com.unity.testtools.codecoverage": "1.0.0-pre.3",
"com.unity.textmeshpro": "2.0.1",

2
DevProject/Packages/packages-lock.json


"url": "https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-candidates"
},
"com.unity.test-framework": {
"version": "1.1.20",
"version": "1.1.22",
"depth": 0,
"source": "registry",
"dependencies": {

18
DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json


"m_Name": "Settings",
"m_Path": "ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json",
"m_Dictionary": {
"m_DictionaryValues": []
"m_DictionaryValues": [
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "Path",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "HistoryPath",
"value": "{\"m_Value\":\"{ProjectPath}\"}"
},
{
"type": "System.String, mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089",
"key": "IncludeAssemblies",
"value": "{\"m_Value\":\"Assembly-CSharp,Runtime,Unity.ML-Agents,Unity.ML-Agents.Extensions\"}"
}
]
}
}

4
DevProject/ProjectSettings/ProjectVersion.txt


m_EditorVersion: 2019.4.19f1
m_EditorVersionWithRevision: 2019.4.19f1 (ca5b14067cec)
m_EditorVersion: 2019.4.20f1
m_EditorVersionWithRevision: 2019.4.20f1 (6dd1c08eedfa)

6
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs


/// Creates a BasicActuator.
/// </summary>
/// <returns></returns>
#pragma warning disable 672
public override IActuator CreateActuator()
#pragma warning restore 672
public override IActuator[] CreateActuators()
return new BasicActuator(basicController);
return new IActuator[] { new BasicActuator(basicController) };
}
public override ActionSpec ActionSpec

4
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs


}
/// <inheritdoc/>
public override int[] GetObservationShape()
public override ObservationSpec GetObservationSpec()
return new[] { BasicController.k_Extents };
return ObservationSpec.Vector(BasicController.k_Extents);
}
/// <inheritdoc/>

8
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


if (positionX == 0)
{
actionMask.WriteMask(0, new[] { k_Left });
actionMask.SetActionEnabled(0, k_Left, false);
actionMask.WriteMask(0, new[] { k_Right });
actionMask.SetActionEnabled(0, k_Right, false);
actionMask.WriteMask(0, new[] { k_Down });
actionMask.SetActionEnabled(0, k_Down, false);
actionMask.WriteMask(0, new[] { k_Up });
actionMask.SetActionEnabled(0, k_Up, false);
}
}
}

6
Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs


public class Match3ExampleActuatorComponent : Match3ActuatorComponent
{
/// <inheritdoc/>
#pragma warning disable 672
public override IActuator CreateActuator()
#pragma warning restore 672
public override IActuator[] CreateActuators()
return new Match3ExampleActuator(board, ForceHeuristic, agent, ActuatorName, seed);
return new IActuator[] { new Match3ExampleActuator(board, ForceHeuristic, agent, ActuatorName, seed) };
}
}
}

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs


public abstract void WriteObservation(float[] output);
/// <inheritdoc/>
public abstract int[] GetObservationShape();
public abstract ObservationSpec GetObservationSpec();
/// <inheritdoc/>
public abstract string GetName();

5
Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta


fileFormatVersion: 2
guid: 8cd4584c2f2cb4c5fb51675d364e10ec
ScriptedImporter:
internalIDToNameTable: []
fileIDToRecycleName:
11400000: main obj
11400002: model data
serializedVersion: 2
userData:
assetBundleName:
assetBundleVariant:

8
Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs


{
Texture2D m_Texture;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
/// <summary>

var width = texture.width;
var height = texture.height;
m_Name = name;
m_Shape = new[] { height, width, 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, 3);
m_CompressionType = compressionType;
}

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

2
com.unity.ml-agents.extensions/Runtime/Input/InputActionActuator.cs


/// <see cref="Agent"/>'s <see cref="BehaviorParameters"/> indicate that the Agent is running in Heuristic Mode,
/// this Actuator will write actions from the <see cref="InputSystem"/> to the <see cref="ActionBuffers"/> object.
/// </summary>
public class InputActionActuator : IActuator, IHeuristicProvider, IBuiltInActuator
public class InputActionActuator : IActuator, IBuiltInActuator
{
readonly BehaviorParameters m_BehaviorParameters;
readonly InputAction m_Action;

5
com.unity.ml-agents.extensions/Runtime/Input/InputActuatorComponent.cs


return inputControlScheme;
}
#pragma warning disable 672
/// <inheritdoc cref="ActuatorComponent.CreateActuator"/>
public override IActuator CreateActuator() { return null; }
#pragma warning restore 672
/// <summary>
///
/// </summary>

69
com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs


/// <param name="agent"></param>
/// <param name="name"></param>
public Match3Actuator(AbstractBoard board,
bool forceHeuristic,
int seed,
Agent agent,
string name)
bool forceHeuristic,
int seed,
Agent agent,
string name)
{
m_Board = board;
m_Rows = board.Rows;

/// <inheritdoc/>
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
const int branch = 0;
bool foundValidMove = false;
actionMask.WriteMask(0, InvalidMoveIndices());
}
}
var numMoves = m_Board.NumMoves();
/// <inheritdoc/>
public string Name { get; }
var currentMove = Move.FromMoveIndex(0, m_Board.Rows, m_Board.Columns);
for (var i = 0; i < numMoves; i++)
{
if (m_Board.IsMoveValid(currentMove))
{
foundValidMove = true;
}
else
{
actionMask.SetActionEnabled(branch, i, false);
}
currentMove.Next(m_Board.Rows, m_Board.Columns);
}
/// <inheritdoc/>
public void ResetData()
{
}
/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.Match3Actuator;
}
IEnumerable<int> InvalidMoveIndices()
{
var numValidMoves = m_Board.NumMoves();
foreach (var move in m_Board.InvalidMoves())
{
numValidMoves--;
if (numValidMoves == 0)
if (!foundValidMove)
{
// If all the moves are invalid and we mask all the actions out, this will cause an assert
// later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one,

"an invalid move will be passed to AbstractBoard.MakeMove()."
);
}
// This means the last move won't be returned as an invalid index.
yield break;
actionMask.SetActionEnabled(branch, numMoves - 1, true);
yield return move.MoveIndex;
/// <inheritdoc/>
public string Name { get; }
/// <inheritdoc/>
public void ResetData()
{
}
/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.Match3Actuator;
}
public void Heuristic(in ActionBuffers actionsOut)
{
var discreteActions = actionsOut.DiscreteActions;

var bestMoveIndex = 0;
var bestMovePoints = -1;
var numMovesAtCurrentScore = 0;

6
com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs


public bool ForceHeuristic;
/// <inheritdoc/>
#pragma warning disable 672
public override IActuator CreateActuator()
#pragma warning restore 672
public override IActuator[] CreateActuators()
return new Match3Actuator(board, ForceHeuristic, seed, agent, ActuatorName);
return new IActuator[] { new Match3Actuator(board, ForceHeuristic, seed, agent, ActuatorName) };
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs


{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
private int[] m_SparseChannelMapping;
private string m_Name;

m_NumSpecialTypes = board.NumSpecialTypes;
m_ObservationType = obsType;
m_Shape = obsType == Match3ObservationType.Vector ?
new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
m_ObservationSpec = obsType == Match3ObservationType.Vector
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

30
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


protected bool Initialized = false;
/// <summary>
/// Array holding the dimensions of the resulting tensor
/// Cached ObservationSpec
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
//
// Debug Parameters

// Default root reference to current game object
if (rootReference == null)
rootReference = gameObject;
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
compressedImgs = new List<byte[]>();
byteSizesBytesList = new List<byte[]>();

CellActivity[i] = DebugDefaultColor;
}
}
}
/// <summary>Gets the shape of the grid observation</summary>
/// <returns>integer array shape of the grid observation</returns>
public int[] GetFloatObservationShape()
{
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
}
/// <inheritdoc/>

/// <summary>Gets the observation shape</summary>
/// <returns>int[] of the observation shape</returns>
public ObservationSpec GetObservationSpec()
{
// Lazy update
var shape = m_ObservationSpec.Shape;
if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell)
{
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
}
return m_ObservationSpec;
}
/// <inheritdoc/>
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
var shape = m_ObservationSpec.Shape;
return new int[] { shape[0], shape[1], shape[2] };
}
/// <inheritdoc/>

10
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


/// </summary>
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_SensorName;
PoseExtractor m_PoseExtractor;

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}
#if UNITY_2020_1_OR_NEWER

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs


var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

8
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

6
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 7 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
}
}

19
com.unity.ml-agents/CHANGELOG.md


## [Unreleased]
### Major Changes
#### com.unity.ml-agents (C#)
- Several breaking interface changes were made. See the
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details.
- Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart.
- The interface for disabling discrete actions in `IDiscreteActionMask` has changed.
`WriteMask(int branch, IEnumerable<int> actionIndices)` was replaced with
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. (#5060)
- IActuator now implements IHeuristicProvider. (#5110)
- `ISensor.GetObservationShape()` was removed, and `GetObservationSpec()` was added. (#5127)
- Make com.unity.modules.unityanalytics an optional dependency. (#5109)
- The `.onnx` models input names have changed. All input placeholders will now use the prefix `obs_` removing the distinction between visual and vector observations. Models created with this version will not be usable with previous versions of the package (#5080)
- The `.onnx` models discrete action output now contains the discrete actions values and not the logits. Models created with this version will not be usable with previous versions of the package (#5080)
- Added ML-Agents package settings. (#5027)
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### com.unity.ml-agents (C#)
## [1.9.0-preview] - 2021-03-17
### Major Changes

#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- Updated com.unity.barracuda to 1.3.2-preview. (#5084)
- Added 3D Ball to the `com.unity.ml-agents` samples. (#5077)
- Make com.unity.modules.unityanalytics an optional dependency. (#5109)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The `encoding_size` setting for RewardSignals has been deprecated. Please use `network_settings` instead. (#4982)
- Sensor names are now passed through to `ObservationSpec.name`. (#5036)

16
com.unity.ml-agents/Editor/BehaviorParametersEditor.cs


using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using CheckTypeEnum = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck.CheckTypeEnum;
namespace Unity.MLAgents.Editor
{

{
if (check != null)
{
EditorGUILayout.HelpBox(check, MessageType.Warning);
switch (check.CheckType)
{
case CheckTypeEnum.Info:
EditorGUILayout.HelpBox(check.Message, MessageType.Info);
break;
case CheckTypeEnum.Warning:
EditorGUILayout.HelpBox(check.Message, MessageType.Warning);
break;
case CheckTypeEnum.Error:
EditorGUILayout.HelpBox(check.Message, MessageType.Error);
break;
default:
break;
}
}
}
}

2
com.unity.ml-agents/Runtime/Academy.cs


// No arg passed, or malformed port number.
#if UNITY_EDITOR
// Try connecting on the default editor port
return k_EditorTrainingPort;
return MLAgentsSettingsManager.Settings.ConnectTrainer ? MLAgentsSettingsManager.Settings.EditorPort : -1;
#else
// This is an executable, so we don't try to connect.
return -1;

14
com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs


public abstract class ActuatorComponent : MonoBehaviour
{
/// <summary>
/// Create the IActuator. This is called by the Agent when it is initialized.
/// </summary>
/// <returns>Created IActuator object.</returns>
[Obsolete("Use CreateActuators instead.")]
public abstract IActuator CreateActuator();
/// <summary>
public virtual IActuator[] CreateActuators()
{
#pragma warning disable 618
return new[] { CreateActuator() };
#pragma warning restore 618
}
public abstract IActuator[] CreateActuators();
/// <summary>
/// The specification of the possible actions for this ActuatorComponent.

29
com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs


}
/// <inheritdoc/>
public void WriteMask(int branch, IEnumerable<int> actionIndices)
public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
// Perform the masking
foreach (var actionIndex in actionIndices)
#if DEBUG
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
#if DEBUG
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true;
}
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = !isEnabled;
}
void LazyInitialize()

}
}
/// <inheritdoc/>
public bool[] GetMask()
/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
internal bool[] GetMask()
{
#if DEBUG
if (m_CurrentMask != null)

/// <summary>
/// Resets the current mask for an agent.
/// </summary>
public void ResetMask()
internal void ResetMask()
{
if (m_CurrentMask != null)
{

4
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs


discreteStart,
numDiscreteActions);
}
var heuristic = actuator as IHeuristicProvider;
heuristic?.Heuristic(new ActionBuffers(continuousActions, discreteActions));
actuator.Heuristic(new ActionBuffers(continuousActions, discreteActions));
continuousStart += numContinuousActions;
discreteStart += numDiscreteActions;
}

43
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs


return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
}
}
/// <summary>
/// Packs the continuous and discrete actions into one float array. The array passed into this method
/// must have a Length that is greater than or equal to the sum of the Lengths of
/// <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/>.
/// </summary>
/// <param name="destination">A float array to pack actions into whose length is greater than or
/// equal to the addition of the Lengths of this objects <see cref="ContinuousActions"/> and
/// <see cref="DiscreteActions"/> segments.</param>
[Obsolete("PackActions has been deprecated.")]
public void PackActions(in float[] destination)
{
Debug.Assert(destination.Length >= ContinuousActions.Length + DiscreteActions.Length,
$"argument '{nameof(destination)}' is not large enough to pack the actions into.\n" +
$"{nameof(destination)}.Length: {destination.Length}\n" +
$"{nameof(ContinuousActions)}.Length + {nameof(DiscreteActions)}.Length: {ContinuousActions.Length + DiscreteActions.Length}");
var start = 0;
if (ContinuousActions.Length > 0)
{
Array.Copy(ContinuousActions.Array,
ContinuousActions.Offset,
destination,
start,
ContinuousActions.Length);
start = ContinuousActions.Length;
}
if (start >= destination.Length)
{
return;
}
if (DiscreteActions.Length > 0)
{
Array.Copy(DiscreteActions.Array,
DiscreteActions.Offset,
destination,
start,
DiscreteActions.Length);
}
}
}
/// <summary>

/// </param>
/// <remarks>
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask"/>.
/// action by masking it with <see cref="IDiscreteActionMask.SetActionEnabled"/>.
///
/// See [Agents - Actions] for more information on masking actions.
///

2
com.unity.ml-agents/Runtime/Actuators/IActuator.cs


/// <summary>
/// Abstraction that facilitates the execution of actions.
/// </summary>
public interface IActuator : IActionReceiver
public interface IActuator : IActionReceiver, IHeuristicProvider
{
/// <summary>
/// The specification of the actions for this IActuator.

24
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs


public interface IDiscreteActionMask
{
/// <summary>
/// Modifies an action mask for discrete control agents.
/// Set whether or not the action index for the given branch is allowed.
/// When used, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndices correspond
/// By default, all discrete actions are allowed.
/// If isEnabled is false, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndex correspond
/// to the action options the agent will be unable to perform.
///
/// See [Agents - Actions] for more information on masking actions.

/// <param name="branch">The branch for which the actions will be masked.</param>
/// <param name="actionIndices">The indices of the masked actions.</param>
void WriteMask(int branch, IEnumerable<int> actionIndices);
/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
bool[] GetMask();
/// <summary>
/// Resets the current mask for an agent.
/// </summary>
void ResetMask();
/// <param name="actionIndex">Index of the action</param>
/// <param name="isEnabled">Whether the action is allowed or now.</param>
void SetActionEnabled(int branch, int actionIndex, bool isEnabled);
}
}

2
com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs


/// <summary>
/// IActuator implementation that forwards calls to an <see cref="IActionReceiver"/> and an <see cref="IHeuristicProvider"/>.
/// </summary>
internal class VectorActuator : IActuator, IHeuristicProvider, IBuiltInActuator
internal class VectorActuator : IActuator, IBuiltInActuator
{
IActionReceiver m_ActionReceiver;
IHeuristicProvider m_HeuristicProvider;

83
com.unity.ml-agents/Runtime/Agent.cs


/// Whether or not the Agent has been initialized already
bool m_Initialized;
/// Keeps track of the actions that are masked at each step.
DiscreteActionMasker m_ActionMasker;
/// <summary>
/// Set of DemonstrationWriters that the Agent will write its step information to.
/// If you use a DemonstrationRecorder component, this will automatically register its DemonstrationWriter.

/// with the current behavior of Agent.
/// </summary>
IActuator m_VectorActuator;
/// <summary>
/// This is used to avoid allocation of a float array every frame if users are still using the old
/// OnActionReceived method.
/// </summary>
float[] m_LegacyActionCache;
/// <summary>
/// This is used to avoid allocation of a float array during legacy calls to Heuristic.
/// </summary>
float[] m_LegacyHeuristicCache;
/// Currect MultiAgentGroup ID. Default to 0 (meaning no group)
int m_GroupId;

/// <seealso cref="IActionReceiver.OnActionReceived"/>
public virtual void Heuristic(in ActionBuffers actionsOut)
{
// Disable deprecation warnings so we can call the legacy overload.
#pragma warning disable CS0618
// The default implementation of Heuristic calls the
// obsolete version for backward compatibility
switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType)
{
case SpaceType.Continuous:
Heuristic(m_LegacyHeuristicCache);
Array.Copy(m_LegacyHeuristicCache, actionsOut.ContinuousActions.Array, m_LegacyActionCache.Length);
actionsOut.DiscreteActions.Clear();
break;
case SpaceType.Discrete:
Heuristic(m_LegacyHeuristicCache);
var discreteActionSegment = actionsOut.DiscreteActions;
for (var i = 0; i < actionsOut.DiscreteActions.Length; i++)
{
discreteActionSegment[i] = (int)m_LegacyHeuristicCache[i];
}
actionsOut.ContinuousActions.Clear();
break;
}
#pragma warning restore CS0618
Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions.");
}
/// <summary>

var param = m_PolicyFactory.BrainParameters;
m_VectorActuator = new AgentVectorActuator(this, this, param.ActionSpec);
m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1);
m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions()];
m_LegacyHeuristicCache = new float[m_VectorActuator.TotalNumberOfActions()];
m_ActuatorManager.Add(m_VectorActuator);

/// </param>
/// <remarks>
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask(int, IEnumerable{int})"/>.
/// action by masking it with <see cref="IDiscreteActionMask.SetActionEnabled"/>.
///
/// See [Agents - Actions] for more information on masking actions.
///

public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
if (m_ActionMasker == null)
{
m_ActionMasker = new DiscreteActionMasker(actionMask);
}
// Disable deprecation warnings so we can call the legacy overload.
#pragma warning disable CS0618
CollectDiscreteActionMasks(m_ActionMasker);
#pragma warning restore CS0618
}
public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { }
/// <summary>
/// Implement `OnActionReceived()` to specify agent behavior at every step, based

/// <param name="actions">
/// Struct containing the buffers of actions to be executed at this step.
/// </param>
public virtual void OnActionReceived(ActionBuffers actions)
{
var actionSpec = m_PolicyFactory.BrainParameters.ActionSpec;
// For continuous and discrete actions together, we don't need to fall back to the legacy method
if (actionSpec.NumContinuousActions > 0 && actionSpec.NumDiscreteActions > 0)
{
// Nothing implemented.
return;
}
if (!actions.ContinuousActions.IsEmpty())
{
Array.Copy(actions.ContinuousActions.Array,
m_LegacyActionCache,
actionSpec.NumContinuousActions);
}
else
{
for (var i = 0; i < m_LegacyActionCache.Length; i++)
{
m_LegacyActionCache[i] = (float)actions.DiscreteActions[i];
}
}
// Disable deprecation warnings so we can call the legacy overload.
#pragma warning disable CS0618
OnActionReceived(m_LegacyActionCache);
#pragma warning restore CS0618
}
public virtual void OnActionReceived(ActionBuffers actions) { }
/// <summary>
/// Implement `OnEpisodeBegin()` to set up an Agent instance at the beginning

7
com.unity.ml-agents/Runtime/Analytics/Events.cs


public static EventObservationSpec FromSensor(ISensor sensor)
{
var shape = sensor.GetObservationShape();
var dimProps = (sensor as IDimensionPropertiesSensor)?.GetDimensionProperties();
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
var dimProps = obsSpec.DimensionProperties;
dimInfos[i].Flags = dimProps != null ? (int)dimProps[i] : 0;
dimInfos[i].Flags = (int)dimProps[i];
}
var builtInSensorType =

55
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


/// <returns></returns>
public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
{
var shape = sensor.GetObservationShape();
var obsSpec = sensor.GetObservationSpec();
var shape = obsSpec.Shape;
ObservationProto observationProto = null;
var compressionType = sensor.GetCompressionType();
// Check capabilities if we need to concatenate PNGs

floatDataProto.Data.Add(0.0f);
}
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0);
sensor.Write(observationWriter);
observationProto = new ObservationProto

observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
}
}
// Add the dimension properties if any to the observationProto
var dimensionPropertySensor = sensor as IDimensionPropertiesSensor;
if (dimensionPropertySensor != null)
// Add the dimension properties to the observationProto
var dimensionProperties = obsSpec.DimensionProperties;
for (int i = 0; i < dimensionProperties.Length; i++)
{
observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
}
// Checking trainer compatibility with variable length observations
if (dimensionProperties == new InplaceArray<DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None))
var dimensionProperties = dimensionPropertySensor.GetDimensionProperties();
for (int i = 0; i < dimensionProperties.Length; i++)
var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
if (!trainerCanHandleVarLenObs)
observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
// Checking trainer compatibility with variable length observations
if (dimensionProperties.Length == 2)
{
if (dimensionProperties[0] == DimensionProperty.VariableSize &&
dimensionProperties[1] == DimensionProperty.None)
{
var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
if (!trainerCanHandleVarLenObs)
{
throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
}
}
}
}
for (var i = 0; i < shape.Length; i++)
{
observationProto.Shape.Add(shape[i]);
observationProto.Shape.AddRange(shape);
var sensorName = sensor.GetName();
if (!string.IsNullOrEmpty(sensorName))
{

// Add the observation type, if any, to the observationProto
var typeSensor = sensor as ITypedSensor;
if (typeSensor != null)
{
observationProto.ObservationType = (ObservationTypeProto)typeSensor.GetObservationType();
}
else
{
observationProto.ObservationType = ObservationTypeProto.Default;
}
observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType;
return observationProto;
}

44
com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs


}
/// <summary>
/// The Applier for the Discrete Action output tensor.
/// </summary>
internal class DiscreteActionOutputApplier : TensorApplier.IApplier
{
readonly ActionSpec m_ActionSpec;
public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
{
m_ActionSpec = actionSpec;
}
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
for (var i = 0; i < actionIds.Count; i++)
{
var agentId = actionIds[i];
if (lastActions.ContainsKey(agentId))
{
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
{
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
}
var discreteBuffer = actionBuffer.DiscreteActions;
for (var j = 0; j < actionSize; j++)
{
discreteBuffer[j] = (int)tensorProxy.data[agentIndex, j];
}
}
agentIndex++;
}
}
}
/// <summary>
internal class DiscreteActionOutputApplier : TensorApplier.IApplier
internal class LegacyDiscreteActionOutputApplier : TensorApplier.IApplier
{
readonly int[] m_ActionSize;
readonly Multinomial m_Multinomial;

public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
public LegacyDiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
{
m_ActionSize = actionSpec.BranchSizes;
m_Multinomial = new Multinomial(seed);

70
com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs


using System.Collections.Generic;
using System.Linq;
using Unity.Barracuda;
using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck;
namespace Unity.MLAgents.Inference
{

names.Sort();
return names.ToArray();
}
/// <summary>
/// Get the version of the model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>The api version of the model</returns>
public static int GetVersion(this Model model)
{
return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
}
/// <summary>

else
{
return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
(int)model.GetTensorByName(TensorNames.DiscreteActionOutputShape)[0] > 0;
(int)model.DiscreteOutputSize() > 0;
/// This method gets the tensor representing the list of branch size and returns the
/// sum of all the elements in the Tensor.
/// - In version 1.X this tensor contains a single number, the sum of all branch
/// size values.
/// - In version 2.X this tensor contains a 1D Tensor with each element corresponding
/// to a branch size.
/// Since this method does the sum of all elements in the tensor, the output
/// will be the same on both 1.X and 2.X.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.

else
{
var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
return discreteOutputShape == null ? 0 : (int)discreteOutputShape[0];
if (discreteOutputShape == null)
{
return 0;
}
else
{
int result = 0;
for (int i = 0; i < discreteOutputShape.length; i++)
{
result += (int)discreteOutputShape[i];
}
return result;
}
}
}

/// <param name="failedModelChecks">Output list of failure messages</param>
///
/// <returns>True if the model contains all the expected tensors.</returns>
public static bool CheckExpectedTensors(this Model model, List<string> failedModelChecks)
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
failedModelChecks.Add($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
);
return false;
}

{
failedModelChecks.Add($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
);
return false;
}

!model.outputs.Contains(TensorNames.DiscreteActionOutput))
{
failedModelChecks.Add("The model does not contain any Action Output Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain any Action Output Node.")
);
return false;
}

if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
{
failedModelChecks.Add("The model does not contain any Action Output Shape Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
);
failedModelChecks.Add($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was not found in the model file. " +
"This is only required for model that uses a deprecated model format.");
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
"not found in the model file. " +
"This is only required for model that uses a deprecated model format.")
);
return false;
}
}

model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
{
failedModelChecks.Add("The model uses continuous action but does not contain Continuous Action Output Shape Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
);
failedModelChecks.Add("The model uses discrete action but does not contain Discrete Action Output Shape Node.");
failedModelChecks.Add(
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
);
return false;
}
}

472
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


/// </summary>
internal class BarracudaModelParamLoader
{
const long k_ApiVersion = 2;
internal enum ModelApiVersion
{
MLAgents1_0 = 2,
MLAgents2_0 = 3,
MinSupportedVersion = MLAgents1_0,
MaxSupportedVersion = MLAgents2_0
}
internal class FailedCheck
{
public enum CheckTypeEnum
{
Info = 0,
Warning = 1,
Error = 2
}
public CheckTypeEnum CheckType;
public string Message;
public static FailedCheck Info(string message)
{
return new FailedCheck { CheckType = CheckTypeEnum.Info, Message = message };
}
public static FailedCheck Warning(string message)
{
return new FailedCheck { CheckType = CheckTypeEnum.Warning, Message = message };
}
public static FailedCheck Error(string message)
{
return new FailedCheck { CheckType = CheckTypeEnum.Error, Message = message };
}
}
/// <summary>
/// Factory for the ModelParamLoader : Creates a ModelParamLoader and runs the checks

/// <param name="actuatorComponents">Attached actuator components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <param name="behaviorType">BehaviorType or the Agent to check.</param>
/// <returns>The list the error messages of the checks that failed</returns>
public static IEnumerable<string> CheckModel(Model model, BrainParameters brainParameters,
ISensor[] sensors, ActuatorComponent[] actuatorComponents,
/// <returns>A IEnumerable of the checks that failed</returns>
public static IEnumerable<FailedCheck> CheckModel(
Model model,
BrainParameters brainParameters,
ISensor[] sensors,
ActuatorComponent[] actuatorComponents,
BehaviorType behaviorType = BehaviorType.Default)
BehaviorType behaviorType = BehaviorType.Default
)
List<string> failedModelChecks = new List<string>();
List<FailedCheck> failedModelChecks = new List<FailedCheck>();
if (model == null)
{
var errorMsg = "There is no model for this Brain; cannot run inference. ";

{
errorMsg += "(But can still train)";
}
failedModelChecks.Add(errorMsg);
failedModelChecks.Add(FailedCheck.Info(errorMsg));
return failedModelChecks;
}

return failedModelChecks;
}
var modelApiVersion = (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
if (modelApiVersion == -1)
{
failedModelChecks.Add(
"Model was not trained using the right version of ML-Agents. " +
"Cannot use this model.");
return failedModelChecks;
}
if (modelApiVersion != k_ApiVersion)
var modelApiVersion = model.GetVersion();
if (modelApiVersion < (int)ModelApiVersion.MinSupportedVersion || modelApiVersion > (int)ModelApiVersion.MaxSupportedVersion)
$"Version of the trainer the model was trained with ({modelApiVersion}) " +
$"is not compatible with the Brain's version ({k_ApiVersion}).");
FailedCheck.Warning($"Version of the trainer the model was trained with ({modelApiVersion}) " +
$"is not compatible with the current range of supported versions: " +
$"({(int)ModelApiVersion.MinSupportedVersion} to {(int)ModelApiVersion.MaxSupportedVersion}).")
);
return failedModelChecks;
}

failedModelChecks.Add($"Missing node in the model provided : {TensorNames.MemorySize}");
failedModelChecks.Add(FailedCheck.Warning($"Missing node in the model provided : {TensorNames.MemorySize}"
));
if (modelApiVersion == (int)ModelApiVersion.MLAgents1_0)
{
failedModelChecks.AddRange(
CheckInputTensorPresenceLegacy(model, brainParameters, memorySize, sensors)
);
failedModelChecks.AddRange(
CheckInputTensorShapeLegacy(model, brainParameters, sensors, observableAttributeTotalSize)
);
}
else if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
{
failedModelChecks.AddRange(
CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
);
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
);
}
CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
CheckOutputTensorShape(model, brainParameters, actuatorComponents)
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
);
failedModelChecks.AddRange(
CheckOutputTensorShape(model, brainParameters, actuatorComponents)
);
/// present in the BrainParameters.
/// present in the BrainParameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters

/// </param>
/// <param name="sensors">Array of attached sensor components</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed input presence checks.
/// A IEnumerable of the checks that failed
static IEnumerable<string> CheckInputTensorPresence(
static IEnumerable<FailedCheck> CheckInputTensorPresenceLegacy(
Model model,
BrainParameters brainParameters,
int memory,

var failedModelChecks = new List<string>();
var failedModelChecks = new List<FailedCheck>();
var tensorsNames = model.GetInputNames();
// If there is no Vector Observation Input but the Brain Parameters expect one.

failedModelChecks.Add(
"The model does not contain a Vector Observation Placeholder Input. " +
"You must set the Vector Observation Space Size to 0.");
FailedCheck.Warning("The model does not contain a Vector Observation Placeholder Input. " +
"You must set the Vector Observation Space Size to 0.")
);
}
// If there are not enough Visual Observation Input compared to what the

{
var sensor = sensors[sensorIndex];
if (sensor.GetObservationShape().Length == 3)
if (sensor.GetObservationSpec().Shape.Length == 3)
"The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
FailedCheck.Warning("The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).")
);
if (sensor.GetObservationShape().Length == 2)
if (sensor.GetObservationSpec().Shape.Length == 2)
"The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).");
FailedCheck.Warning("The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).")
);
}
}

if (expectedVisualObs > visObsIndex)
{
failedModelChecks.Add(
$"The model expects {expectedVisualObs} visual inputs," +
$" but only found {visObsIndex} visual sensors."
);
FailedCheck.Warning($"The model expects {expectedVisualObs} visual inputs," +
$" but only found {visObsIndex} visual sensors.")
);
}
// If the model has a non-negative memory size but requires a recurrent input

!tensorsNames.Any(x => x.EndsWith("_c")))
{
failedModelChecks.Add(
"The model does not contain a Recurrent Input Node but has memory_size.");
FailedCheck.Warning("The model does not contain a Recurrent Input Node but has memory_size.")
);
}
}

if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{
failedModelChecks.Add(
"The model does not contain an Action Mask but is using Discrete Control.");
FailedCheck.Warning("The model does not contain an Action Mask but is using Discrete Control.")
);
}
}
return failedModelChecks;
}
/// <summary>
/// Generates failed checks that correspond to inputs expected by the model that are not
/// present in the BrainParameters.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="memory">
/// The memory size that the model is expecting.
/// </param>
/// <param name="sensors">Array of attached sensor components</param>
/// <returns>
/// A IEnumerable of the checks that failed
/// </returns>
static IEnumerable<FailedCheck> CheckInputTensorPresence(
Model model,
BrainParameters brainParameters,
int memory,
ISensor[] sensors
)
{
var failedModelChecks = new List<FailedCheck>();
var tensorsNames = model.GetInputNames();
for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
if (!tensorsNames.Contains(
TensorNames.GetObservationName(sensorIndex)))
{
var sensor = sensors[sensorIndex];
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).")
);
}
}
// If the model has a non-negative memory size but requires a recurrent input
if (memory > 0)
{
if (!tensorsNames.Any(x => x.EndsWith("_h")) ||
!tensorsNames.Any(x => x.EndsWith("_c")))
{
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain a Recurrent Input Node but has memory_size.")
);
}
}
// If the model uses discrete control but does not have an input for action masks
if (model.HasDiscreteOutputs())
{
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain an Action Mask but is using Discrete Control.")
);
}
}
return failedModelChecks;

/// </param>
/// <param name="memory">The memory size that the model is expecting/</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed output presence checks.
/// A IEnumerable of the checks that failed
static IEnumerable<string> CheckOutputTensorPresence(Model model, int memory)
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory)
var failedModelChecks = new List<string>();
var failedModelChecks = new List<FailedCheck>();
// If there is no Recurrent Output but the model is Recurrent.
if (memory > 0)

!memOutputs.Any(x => x.EndsWith("_c")))
{
failedModelChecks.Add(
"The model does not contain a Recurrent Output Node but has memory_size.");
FailedCheck.Warning("The model does not contain a Recurrent Output Node but has memory_size.")
);
}
}
return failedModelChecks;

/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckVisualObsShape(
static FailedCheck CheckVisualObsShape(
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var heightBp = shape[0];
var widthBp = shape[1];
var pixelBp = shape[2];

if ((widthBp != widthT) || (heightBp != heightT) || (pixelBp != pixelT))
{
return $"The visual Observation of the model does not match. " +
return FailedCheck.Warning($"The visual Observation of the model does not match. " +
$"was expecting [?x{widthT}x{heightT}x{pixelT}].";
$"was expecting [?x{widthT}x{heightT}x{pixelT}] for the {sensor.GetName()} Sensor."
);
}
return null;
}

/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckRankTwoObsShape(
static FailedCheck CheckRankTwoObsShape(
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var dim3T = tensorProxy.Height;
return $"An Observation of the model does not match. " +
var proxyDimStr = $"[?x{dim1T}x{dim2T}]";
if (dim3T > 1)
{
proxyDimStr = $"[?x{dim3T}x{dim2T}x{dim1T}]";
}
return FailedCheck.Warning($"An Observation of the model does not match. " +
$"was expecting [?x{dim1T}x{dim2T}].";
$"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor."
);
}
return null;
}
/// <summary>
/// Checks that the shape of the rank 2 observation input placeholder is the same as the corresponding sensor.
/// </summary>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensor">The sensor that produces the visual observation.</param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static FailedCheck CheckRankOneObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationSpec().Shape;
var dim1Bp = shape[0];
var dim1T = tensorProxy.Channels;
var dim2T = tensorProxy.Width;
var dim3T = tensorProxy.Height;
if ((dim1Bp != dim1T))
{
var proxyDimStr = $"[?x{dim1T}]";
if (dim2T > 1)
{
proxyDimStr = $"[?x{dim1T}x{dim2T}]";
}
if (dim3T > 1)
{
proxyDimStr = $"[?x{dim3T}x{dim2T}x{dim1T}]";
}
return FailedCheck.Warning($"An Observation of the model does not match. " +
$"Received TensorProxy of shape [?x{dim1Bp}] but " +
$"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor."
);
}
return null;
}

/// the model and the BrainParameters.
/// the model and the BrainParameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters

/// </param>
/// <param name="sensors">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>The list the error messages of the checks that failed</returns>
static IEnumerable<string> CheckInputTensorShape(
/// <returns>A IEnumerable of the checks that failed</returns>
static IEnumerable<FailedCheck> CheckInputTensorShapeLegacy(
var failedModelChecks = new List<string>();
var failedModelChecks = new List<FailedCheck>();
new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, string>>()
new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, FailedCheck>>()
{TensorNames.VectorObservationPlaceholder, CheckVectorObsShape},
{TensorNames.VectorObservationPlaceholder, CheckVectorObsShapeLegacy},
{TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},

for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationShape().Length == 3)
if (sens.GetObservationSpec().Shape.Length == 3)
{
tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] =

if (sens.GetObservationShape().Length == 2)
if (sens.GetObservationSpec().Shape.Length == 2)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);

if (!tensor.name.Contains("visual_observation"))
{
failedModelChecks.Add(
"Model requires an unknown input named : " + tensor.name);
FailedCheck.Warning("Model contains an unexpected input named : " + tensor.name)
);
}
}
else

/// <summary>
/// Checks that the shape of the Vector Observation input placeholder is the same in the
/// model and in the Brain Parameters.
/// model and in the Brain Parameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine

/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckVectorObsShape(
static FailedCheck CheckVectorObsShapeLegacy(
BrainParameters brainParameters, TensorProxy tensorProxy, ISensor[] sensors,
int observableAttributeTotalSize)
{

var totalVectorSensorSize = 0;
foreach (var sens in sensors)
{
if ((sens.GetObservationShape().Length == 1))
if ((sens.GetObservationSpec().Shape.Length == 1))
totalVectorSensorSize += sens.GetObservationShape()[0];
totalVectorSensorSize += sens.GetObservationSpec().Shape[0];
}
}

foreach (var sensorComp in sensors)
{
if (sensorComp.GetObservationShape().Length == 1)
if (sensorComp.GetObservationSpec().Shape.Length == 1)
var vecSize = sensorComp.GetObservationShape()[0];
var vecSize = sensorComp.GetObservationSpec().Shape[0];
if (sensorSizes.Length == 0)
{
sensorSizes = $"[{vecSize}";

}
sensorSizes += "]";
return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
return FailedCheck.Warning(
$"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
$"Sensor sizes: {sensorSizes}.";
$"Sensor sizes: {sensorSizes}."
);
/// <summary>
/// Generates failed checks that correspond to inputs shapes incompatibilities between
/// the model and the BrainParameters.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensors">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>A IEnumerable of the checks that failed</returns>
static IEnumerable<FailedCheck> CheckInputTensorShape(
Model model, BrainParameters brainParameters, ISensor[] sensors,
int observableAttributeTotalSize)
{
var failedModelChecks = new List<FailedCheck>();
var tensorTester =
new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, FailedCheck>>()
{
{TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)},
};
foreach (var mem in model.memories)
{
tensorTester[mem.input] = ((bp, tensor, scs, i) => null);
}
for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationSpec().Rank == 3)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens);
}
if (sens.GetObservationSpec().Rank == 2)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);
}
if (sens.GetObservationSpec().Rank == 1)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankOneObsShape(tensor, sens);
}
}
// If the model expects an input but it is not in this list
foreach (var tensor in model.GetInputTensors())
{
if (!tensorTester.ContainsKey(tensor.name))
{
failedModelChecks.Add(FailedCheck.Warning("Model contains an unexpected input named : " + tensor.name
));
}
else
{
var tester = tensorTester[tensor.name];
var error = tester.Invoke(brainParameters, tensor, sensors, observableAttributeTotalSize);
if (error != null)
{
failedModelChecks.Add(error);
}
}
}
return failedModelChecks;
}
/// <summary>
/// Checks that the shape of the Previous Vector Action input placeholder is the same in the
/// model and in the Brain Parameters.

/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes (unused).</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
static string CheckPreviousActionShape(
static FailedCheck CheckPreviousActionShape(
BrainParameters brainParameters, TensorProxy tensorProxy,
ISensor[] sensors, int observableAttributeTotalSize)
{

{
return "Previous Action Size of the model does not match. " +
$"Received {numberActionsBp} but was expecting {numberActionsT}.";
return FailedCheck.Warning("Previous Action Size of the model does not match. " +
$"Received {numberActionsBp} but was expecting {numberActionsT}."
);
}
return null;
}

/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <returns>
/// A IEnumerable of string corresponding to the incompatible shapes between model
/// A IEnumerable of error messages corresponding to the incompatible shapes between model
static IEnumerable<string> CheckOutputTensorShape(
static IEnumerable<FailedCheck> CheckOutputTensorShape(
var failedModelChecks = new List<string>();
var failedModelChecks = new List<FailedCheck>();
// If the model expects an output but it is not in this list
var modelContinuousActionSize = model.ContinuousOutputSize();

failedModelChecks.Add(continuousError);
}
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
var discreteError = CheckDiscreteActionOutputShape(brainParameters, actuatorComponents, modelSumDiscreteBranchSizes);
FailedCheck discreteError = null;
var modelApiVersion = model.GetVersion();
if (modelApiVersion == (int)ModelApiVersion.MLAgents1_0)
{
var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
discreteError = CheckDiscreteActionOutputShapeLegacy(brainParameters, actuatorComponents, modelSumDiscreteBranchSizes);
}
if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
{
var modelDiscreteBranches = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
discreteError = CheckDiscreteActionOutputShape(brainParameters, actuatorComponents, modelDiscreteBranches);
}
if (discreteError != null)
{
failedModelChecks.Add(discreteError);

/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <param name="modelDiscreteBranches"> The Tensor of branch sizes.
/// </param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static FailedCheck CheckDiscreteActionOutputShape(
BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, Tensor modelDiscreteBranches)
{
var discreteActionBranches = brainParameters.ActionSpec.BranchSizes.ToList();
foreach (var actuatorComponent in actuatorComponents)
{
var actionSpec = actuatorComponent.ActionSpec;
discreteActionBranches.AddRange(actionSpec.BranchSizes);
}
int modelDiscreteBranchesLength = modelDiscreteBranches?.length ?? 0;
if (modelDiscreteBranchesLength != discreteActionBranches.Count)
{
return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " +
$"{discreteActionBranches.Count} branches but the model contains {modelDiscreteBranchesLength}."
);
}
for (int i = 0; i < modelDiscreteBranchesLength; i++)
{
if (modelDiscreteBranches[i] != discreteActionBranches[i])
{
return FailedCheck.Warning($"The number of Discrete Actions of branch {i} does not match. " +
$"Was expecting {discreteActionBranches[i]} but the model contains {modelDiscreteBranches[i]} "
);
}
}
return null;
}
/// <summary>
/// Checks that the shape of the discrete action output is the same in the
/// model and in the Brain Parameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <param name="modelSumDiscreteBranchSizes">
/// The size of the discrete action output that is expected by the model.
/// </param>

/// </returns>
static string CheckDiscreteActionOutputShape(
static FailedCheck CheckDiscreteActionOutputShapeLegacy(
BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelSumDiscreteBranchSizes)
{
// TODO: check each branch size instead of sum of branch sizes

if (modelSumDiscreteBranchSizes != sumOfDiscreteBranchSizes)
{
return "Discrete Action Size of the model does not match. The BrainParameters expect " +
$"{sumOfDiscreteBranchSizes} but the model contains {modelSumDiscreteBranchSizes}.";
return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " +
$"{sumOfDiscreteBranchSizes} but the model contains {modelSumDiscreteBranchSizes}."
);
}
return null;
}

/// </param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
static string CheckContinuousActionOutputShape(
static FailedCheck CheckContinuousActionOutputShape(
BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelContinuousActionSize)
{
var numContinuousActions = brainParameters.ActionSpec.NumContinuousActions;

if (modelContinuousActionSize != numContinuousActions)
{
return "Continuous Action Size of the model does not match. The BrainParameters and ActuatorComponents expect " +
$"{numContinuousActions} but the model contains {modelContinuousActionSize}.";
return FailedCheck.Warning(
"Continuous Action Size of the model does not match. The BrainParameters and ActuatorComponents expect " +
$"{numContinuousActions} but the model contains {modelContinuousActionSize}."
);
}
return null;
}

10
com.unity.ml-agents/Runtime/Inference/TensorApplier.cs


if (actionSpec.NumDiscreteActions > 0)
{
var tensorName = model.DiscreteOutputName();
m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
var modelVersion = model.GetVersion();
if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
{
m_Dict[tensorName] = new LegacyDiscreteActionOutputApplier(actionSpec, seed, allocator);
}
if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
{
m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
}
}
m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories);

93
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


}
readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();
int m_ApiVersion;
/// <summary>
/// Returns a new TensorGenerators object.

return;
}
var model = (Model)barracudaModel;
m_ApiVersion = model.GetVersion();
// Generator for Inputs
m_Dict[TensorNames.BatchSizePlaceholder] =

public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
{
// Loop through the sensors on a representative agent.
// All vector observations use a shared ObservationGenerator since they are concatenated.
// All other observations use a unique ObservationInputGenerator
var visIndex = 0;
ObservationGenerator vecObsGen = null;
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
{
// Loop through the sensors on a representative agent.
// All vector observations use a shared ObservationGenerator since they are concatenated.
// All other observations use a unique ObservationInputGenerator
var visIndex = 0;
ObservationGenerator vecObsGen = null;
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var sensor = sensors[sensorIndex];
var rank = sensor.GetObservationSpec().Rank;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
{
case 1:
if (vecObsGen == null)
{
vecObsGen = new ObservationGenerator(allocator);
}
obsGen = vecObsGen;
obsGenName = TensorNames.VectorObservationPlaceholder;
break;
case 2:
// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetObservationName(sensorIndex);
break;
case 3:
// If the tensor is of rank 3, we use the "visual observation
// index", which only counts the rank 3 sensors
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetVisualObservationName(visIndex);
visIndex++;
break;
default:
throw new UnityAgentsException(
$"Sensor {sensor.GetName()} have an invalid rank {rank}");
}
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}
if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
var sensor = sensors[sensorIndex];
var shape = sensor.GetObservationShape();
var rank = shape.Length;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
case 1:
if (vecObsGen == null)
{
vecObsGen = new ObservationGenerator(allocator);
}
obsGen = vecObsGen;
obsGenName = TensorNames.VectorObservationPlaceholder;
break;
case 2:
// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetObservationName(sensorIndex);
break;
case 3:
// If the tensor is of rank 3, we use the "visual observation
// index", which only counts the rank 3 sensors
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.GetVisualObservationName(visIndex);
visIndex++;
break;
default:
throw new UnityAgentsException(
$"Sensor {sensor.GetName()} have an invalid rank {rank}");
var obsGen = new ObservationGenerator(allocator);
var obsGenName = TensorNames.GetObservationName(sensorIndex);
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}

12
com.unity.ml-agents/Runtime/Policies/BrainParameters.cs


internal bool hasUpgradedBrainParametersWithActionSpec;
/// <summary>
/// (Deprecated) The number of actions specified by this Brain.
/// </summary>
[Obsolete("NumActions has been deprecated, please use ActionSpec instead.")]
public int NumActions
{
get
{
return ActionSpec.NumContinuousActions > 0 ? ActionSpec.NumContinuousActions : ActionSpec.NumDiscreteActions;
}
}
/// <summary>
/// Deep clones the BrainParameter object.
/// </summary>
/// <returns> A new BrainParameter object with the same values as the original.</returns>

2
com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs


{
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationShape(), 0);
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationSpec(), 0);
sensor.Write(m_ObservationWriter);
}
else

4
com.unity.ml-agents/Runtime/SensorHelper.cs


}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

19
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs


/// <summary>
/// A Sensor that allows to observe a variable number of entities.
/// </summary>
public class BufferSensor : ISensor, IDimensionPropertiesSensor, IBuiltInSensor
public class BufferSensor : ISensor, IBuiltInSensor
{
private string m_Name;
private int m_MaxNumObs;

static DimensionProperty[] s_DimensionProperties = new DimensionProperty[]{
DimensionProperty.VariableSize,
DimensionProperty.None
};
ObservationSpec m_ObservationSpec;
/// <summary>
/// Creates the BufferSensor.

m_ObsSize = obsSize;
m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs];
m_CurrentNumObservables = 0;
m_ObservationSpec = ObservationSpec.VariableLength(m_MaxNumObs, m_ObsSize);
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new int[] { m_MaxNumObs, m_ObsSize };
}
/// <inheritdoc/>
public DimensionProperty[] GetDimensionProperties()
{
return s_DimensionProperties;
return m_ObservationSpec;
}
/// <summary>

32
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


/// <summary>
/// A sensor that wraps a Camera object to generate visual observations for an agent.
/// </summary>
public class CameraSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor
public class CameraSensor : ISensor, IBuiltInSensor
{
Camera m_Camera;
int m_Width;

int[] m_Shape;
private ObservationSpec m_ObservationSpec;
static DimensionProperty[] s_DimensionProperties = new DimensionProperty[] {
DimensionProperty.TranslationalEquivariance,
DimensionProperty.TranslationalEquivariance,
DimensionProperty.None };
/// <summary>
/// The Camera used for rendering the sensor observations.

m_Height = height;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = GenerateShape(width, height, grayscale);
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
m_CompressionType = compression;
}

}
/// <summary>
/// Accessor for the size of the sensor data. Will be h x w x 1 for grayscale and
/// h x w x 3 for color.
/// </summary>
/// <returns>Size of each of the three dimensions.</returns>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <summary>
/// Accessor for the dimension properties of a camera sensor. A camera sensor
/// Has translational equivariance along width and hight and no property along
/// the channels dimension.
/// Returns a description of the observations that will be generated by the sensor.
/// The shape will be h x w x 1 for grayscale and h x w x 3 for color.
/// The dimensions have translational equivariance along width and height,
/// and no property along the channels dimension.
public DimensionProperty[] GetDimensionProperties()
public ObservationSpec GetObservationSpec()
return s_DimensionProperties;
return m_ObservationSpec;
}
/// <summary>

70
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


}
/// <summary>
/// The Dimension property flags of the observations
/// </summary>
[System.Flags]
public enum DimensionProperty
{
/// <summary>
/// No properties specified.
/// </summary>
Unspecified = 0,
/// <summary>
/// No Property of the observation in that dimension. Observation can be processed with
/// fully connected networks.
/// </summary>
None = 1,
/// <summary>
/// Means it is suitable to do a convolution in this dimension.
/// </summary>
TranslationalEquivariance = 2,
/// <summary>
/// Means that there can be a variable number of observations in this dimension.
/// The observations are unordered.
/// </summary>
VariableSize = 4,
}
/// <summary>
/// The ObservationType enum of the Sensor.
/// </summary>
public enum ObservationType
{
/// <summary>
/// Collected observations are generic.
/// </summary>
Default = 0,
/// <summary>
/// Collected observations contain goal information.
/// </summary>
Goal = 1,
/// <summary>
/// Collected observations contain reward information.
/// </summary>
Reward = 2,
/// <summary>
/// Collected observations are messages from other agents.
/// </summary>
Message = 3,
}
/// <summary>
/// Returns the size of the observations that will be generated.
/// For example, a sensor that observes the velocity of a rigid body (in 3D) would return
/// new {3}. A sensor that returns an RGB image would return new [] {Height, Width, 3}
/// Returns a description of the observations that will be generated by the sensor.
/// See <see cref="ObservationSpec"/> for more details, and helper methods to create one.
/// <returns>Size of the observations that will be generated.</returns>
int[] GetObservationShape();
/// <returns></returns>
ObservationSpec GetObservationSpec();
/// <summary>
/// Write the observation data directly to the <see cref="ObservationWriter"/>.

/// <returns></returns>
public static int ObservationSize(this ISensor sensor)
{
var shape = sensor.GetObservationShape();
var obsSpec = sensor.GetObservationSpec();
foreach (var dim in shape)
for (var i = 0; i < obsSpec.Rank; i++)
count *= dim;
count *= obsSpec.Shape[i];
}
return count;

41
com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs


/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
/// <param name="observationSpec">ObservationSpec of the observation to be written</param>
/// <param name="offset">Offset from the start of the float data to write to.</param>
internal void SetTarget(IList<float> data, ObservationSpec observationSpec, int offset)
{
SetTarget(data, observationSpec.Shape, offset);
}
/// <summary>
/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
internal void SetTarget(IList<float> data, int[] shape, int offset)
internal void SetTarget(IList<float> data, InplaceArray<int> shape, int offset)
{
m_Data = data;
m_Offset = offset;

else
{
m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value;
}
}
}
/// <summary>
/// Write the range of floats
/// </summary>
/// <param name="data"></param>
/// <param name="writeOffset">Optional write offset.</param>
[Obsolete("Use AddList() for better performance")]
public void AddRange(IEnumerable<float> data, int writeOffset = 0)
{
if (m_Data != null)
{
int index = 0;
foreach (var val in data)
{
m_Data[index + m_Offset + writeOffset] = val;
index++;
}
}
else
{
int index = 0;
foreach (var val in data)
{
m_Proxy.data[m_Batch, index + m_Offset + writeOffset] = val;
index++;
}
}
}

8
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


public class RayPerceptionSensor : ISensor, IBuiltInSensor
{
float[] m_Observations;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_Name;
RayPerceptionInput m_RayPerceptionInput;

void SetNumObservations(int numObservations)
{
m_Shape = new[] { numObservations };
m_ObservationSpec = ObservationSpec.Vector(numObservations);
m_Observations = new float[numObservations];
}

public void Reset() { }
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


// Cached sensor names and shapes.
string m_SensorName;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
int m_NumFloats;
public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size)
{

m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute;
m_SensorName = reflectionSensorInfo.SensorName;
m_Shape = new[] { size };
m_ObservationSpec = ObservationSpec.Vector(size);
m_NumFloats = size;
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

return m_Shape[0];
return m_NumFloats;
}
internal abstract void WriteReflectedField(ObservationWriter writer);

8
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs


RenderTexture m_RenderTexture;
bool m_Grayscale;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
/// <summary>

var height = renderTexture != null ? renderTexture.height : 0;
m_Grayscale = grayscale;
m_Name = name;
m_Shape = new[] { height, width, grayscale ? 1 : 3 };
m_ObservationSpec = ObservationSpec.Visual(height, width, grayscale ? 1 : 3);
m_CompressionType = compressionType;
}

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

21
com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs


/// <returns>Shape of the sensor observation.</returns>
public abstract int[] GetObservationShape();
/// <summary>
/// Whether the observation is visual or not.
/// </summary>
/// <returns>True if the observation is visual, false otherwise.</returns>
[Obsolete("IsVisual is deprecated, please use GetObservationShape() instead.")]
public virtual bool IsVisual()
{
var shape = GetObservationShape();
return shape.Length == 3;
}
/// <summary>
/// Whether the observation is vector or not.
/// </summary>
/// <returns>True if the observation is vector, false otherwise.</returns>
[Obsolete("IsVisual is deprecated, please use GetObservationShape() instead.")]
public virtual bool IsVector()
{
var shape = GetObservationShape();
return shape.Length == 1;
}
}
}

21
com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs


{
internal class SensorShapeValidator
{
List<int[]> m_SensorShapes;
List<ObservationSpec> m_SensorShapes;
/// <summary>
/// Check that the List Sensors are the same shape as the previous ones.

{
if (m_SensorShapes == null)
{
m_SensorShapes = new List<int[]>(sensors.Count);
m_SensorShapes = new List<ObservationSpec>(sensors.Count);
m_SensorShapes.Add(sensor.GetObservationShape());
m_SensorShapes.Add(sensor.GetObservationSpec());
}
}
else

);
for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++)
{
var cachedShape = m_SensorShapes[i];
var sensorShape = sensors[i].GetObservationShape();
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
for (var j = 0; j < Mathf.Min(cachedShape.Length, sensorShape.Length); j++)
{
Debug.Assert(cachedShape[j] == sensorShape[j], "Sensor sizes must match.");
}
var cachedSpec = m_SensorShapes[i];
var sensorSpec = sensors[i].GetObservationSpec();
Debug.AssertFormat(
cachedSpec.Shape == sensorSpec.Shape,
"Sensor shapes must match. {0} != {1}",
cachedSpec.Shape,
sensorSpec.Shape
);
}
}
}

48
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


int m_UnstackedObservationSize;
string m_Name;
int[] m_Shape;
int[] m_WrappedShape;
private ObservationSpec m_ObservationSpec;
private ObservationSpec m_WrappedSpec;
/// <summary>
/// Buffer of previous observations

m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";
m_WrappedShape = wrapped.GetObservationShape();
m_Shape = new int[m_WrappedShape.Length];
m_WrappedSpec = wrapped.GetObservationSpec();
for (int d = 0; d < m_WrappedShape.Length; d++)
{
m_Shape[d] = m_WrappedShape[d];
}
// Set up the cached observation spec for the StackingSensor
var newShape = m_WrappedSpec.Shape;
m_Shape[m_Shape.Length - 1] *= numStackedObservations;
newShape[newShape.Length - 1] *= numStackedObservations;
m_ObservationSpec = new ObservationSpec(
newShape, m_WrappedSpec.DimensionProperties, m_WrappedSpec.ObservationType
);
// Initialize uncompressed buffer anyway in case python trainer does not
// support the compression mapping and has to fall back to uncompressed obs.

m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
}
if (m_Shape.Length != 1)
if (m_WrappedSpec.Rank != 1)
m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]);
var wrappedShape = m_WrappedSpec.Shape;
m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]);
}
}

// First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one.
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedShape, 0);
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec, 0);
if (m_WrappedShape.Length == 1)
if (m_WrappedSpec.Rank == 1)
{
for (var i = 0; i < m_NumStackedObservations; i++)
{

for (var i = 0; i < m_NumStackedObservations; i++)
{
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
for (var h = 0; h < m_WrappedShape[0]; h++)
for (var h = 0; h < m_WrappedSpec.Shape[0]; h++)
for (var w = 0; w < m_WrappedShape[1]; w++)
for (var w = 0; w < m_WrappedSpec.Shape[1]; w++)
for (var c = 0; c < m_WrappedShape[2]; c++)
for (var c = 0; c < m_WrappedSpec.Shape[2]; c++)
writer[h, w, i * m_WrappedShape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
writer[h, w, i * m_WrappedSpec.Shape[2] + c] = m_StackedObservations[obsIndex][m_tensorShape.Index(0, h, w, c)];
numWritten = m_WrappedShape[0] * m_WrappedShape[1] * m_WrappedShape[2] * m_NumStackedObservations;
numWritten = m_WrappedSpec.Shape[0] * m_WrappedSpec.Shape[1] * m_WrappedSpec.Shape[2] * m_NumStackedObservations;
}
return numWritten;

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

/// </summary>
internal byte[] CreateEmptyPNG()
{
int height = m_WrappedSensor.GetObservationShape()[0];
int width = m_WrappedSensor.GetObservationShape()[1];
var shape = m_WrappedSpec.Shape;
int height = shape[0];
int width = shape[1];
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
Color32[] resetColorArray = texture2D.GetPixels32();
Color32 black = new Color32(0, 0, 0, 0);

// wrapped sensor doesn't have one, use default mapping.
// Default mapping: {0, 0, 0} for grayscale, identity mapping {1, 2, ..., n} otherwise.
int[] wrappedMapping = null;
int wrappedNumChannel = wrappedSenesor.GetObservationShape()[2];
int wrappedNumChannel = m_WrappedSpec.Shape[2];
var sparseChannelSensor = m_WrappedSensor as ISparseChannelSensor;
if (sparseChannelSensor != null)
{

23
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


// TODO use float[] instead
// TODO allow setting float[]
List<float> m_Observations;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_Name;
/// <summary>

m_Observations = new List<float>(observationSize);
m_Name = name;
m_Shape = new[] { observationSize };
m_ObservationSpec = ObservationSpec.Vector(observationSize);
var expectedObservations = m_Shape[0];
var expectedObservations = m_ObservationSpec.Shape[0];
if (m_Observations.Count > expectedObservations)
{
// Too many observations, truncate

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

{
AddFloatObs(observation.x);
AddFloatObs(observation.y);
}
/// <summary>
/// Adds a collection of float observations to the vector observations of the agent.
/// </summary>
/// <param name="observation">Observation.</param>
[Obsolete("Use AddObservation(IList<float>) for better performance.")]
public void AddObservation(IEnumerable<float> observation)
{
foreach (var f in observation)
{
AddFloatObs(f);
}
}
/// <summary>

25
com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs


}
}
}
/// <summary>
/// Deprecated, use <see cref="SideChannelManager"/> instead.
/// </summary>
[Obsolete("Use SideChannelManager instead.")]
public static class SideChannelsManager
{
/// <summary>
/// Deprecated, use <see cref="SideChannelManager.RegisterSideChannel"/> instead.
/// </summary>
/// <param name="sideChannel"></param>
public static void RegisterSideChannel(SideChannel sideChannel)
{
SideChannelManager.RegisterSideChannel(sideChannel);
}
/// <summary>
/// Deprecated, use <see cref="SideChannelManager.UnregisterSideChannel"/> instead.
/// </summary>
/// <param name="sideChannel"></param>
public static void UnregisterSideChannel(SideChannel sideChannel)
{
SideChannelManager.UnregisterSideChannel(sideChannel);
}
}
}

45
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs


var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
var mask = masker.GetMask();
Assert.IsNull(mask);
masker.WriteMask(0, new[] { 1, 2, 3 });
masker.SetActionEnabled(0, 1, false);
masker.SetActionEnabled(0, 2, false);
masker.SetActionEnabled(0, 3, false);
mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsTrue(mask[1]);

}
[Test]
public void CanOverwriteMask()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
masker.SetActionEnabled(0, 1, false);
var mask = masker.GetMask();
Assert.IsTrue(mask[1]);
masker.SetActionEnabled(0, 1, true);
Assert.IsFalse(mask[1]);
}
[Test]
masker.WriteMask(1, new[] { 1, 2, 3 });
masker.SetActionEnabled(1, 1, false);
masker.SetActionEnabled(1, 2, false);
masker.SetActionEnabled(1, 3, false);
var mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsFalse(mask[4]);

{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
masker.WriteMask(1, new[] { 1, 2, 3 });
masker.SetActionEnabled(1, 1, false);
masker.SetActionEnabled(1, 2, false);
masker.SetActionEnabled(1, 3, false);
masker.ResetMask();
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)

var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(0, new[] { 5 }));
() => masker.SetActionEnabled(0, 5, false));
() => masker.WriteMask(1, new[] { 5 }));
masker.WriteMask(2, new[] { 5 });
() => masker.SetActionEnabled(1, 5, false));
masker.SetActionEnabled(2, 5, false);
() => masker.WriteMask(3, new[] { 1 }));
() => masker.SetActionEnabled(3, 1, false));
masker.WriteMask(0, new[] { 0, 1, 2, 3 });
masker.SetActionEnabled(0, 0, false);
masker.SetActionEnabled(0, 1, false);
masker.SetActionEnabled(0, 2, false);
masker.SetActionEnabled(0, 3, false);
Assert.Catch<UnityAgentsException>(
() => masker.GetMask());
}

{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
masker.WriteMask(0, new[] { 0, 1 });
masker.WriteMask(0, new[] { 3 });
masker.WriteMask(2, new[] { 1 });
masker.SetActionEnabled(0, 0, false);
masker.SetActionEnabled(0, 1, false);
masker.SetActionEnabled(0, 3, false);
masker.SetActionEnabled(2, 1, false);
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{

6
com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs


public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
actionMask.WriteMask(i, Masks[i]);
foreach (var actionIndex in Masks[i])
{
actionMask.SetActionEnabled(i, actionIndex, false);
}
}
}

5
com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs


public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
actionMask.WriteMask(Branch, Mask);
foreach (var actionIndex in Mask)
{
actionMask.SetActionEnabled(Branch, actionIndex, false);
}
}
public void Heuristic(in ActionBuffers actionBuffersOut)

2
com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs


[TestFixture]
public class InferenceAnalyticsTests
{
const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx";
const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx";
NNModel continuousONNXModel;
Test3DSensorComponent sensor_21_20_3;
Test3DSensorComponent sensor_20_22_3;

2
com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs


[TestFixture]
public class BehaviorParameterTests : IHeuristicProvider
{
const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx";
const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx";
public void Heuristic(in ActionBuffers actionsOut)
{
// No-op

23
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


using System;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Demonstrations;

class DummySensor : ISensor
{
public int[] Shape;
public ObservationSpec ObservationSpec;
public SensorCompressionType CompressionType;
internal DummySensor()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return Shape;
return ObservationSpec;
}
public int Write(ObservationWriter writer)

foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants)
{
var inplaceShape = InplaceArray<int>.FromList(shape);
dummySensor.Shape = shape;
if (shape.Length == 1)
{
dummySensor.ObservationSpec = ObservationSpec.Vector(shape[0]);
}
else if (shape.Length == 3)
{
dummySensor.ObservationSpec = ObservationSpec.Visual(shape[0], shape[1], shape[2]);
}
else
{
throw new ArgumentOutOfRangeException();
}
obsWriter.SetTarget(new float[128], shape, 0);
obsWriter.SetTarget(new float[128], inplaceShape, 0);
var caps = new UnityRLCapabilities
{

10
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


sensorName = n;
}
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new[] { 0 };
return ObservationSpec.Vector(0);
}
public int Write(ObservationWriter writer)

Assert.AreEqual(numSteps, agent1.sensor1.numWriteCalls);
Assert.AreEqual(numSteps, agent1.sensor2.numCompressedCalls);
// Disable deprecation warnings so we can read/write the old fields.
#pragma warning disable CS0618
// Make sure the Heuristic method read the observation and set the action
Assert.AreEqual(agent1.collectObservationsCallsForEpisode, agent1.GetAction()[0]);
#pragma warning restore CS0618
}
}

6
com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs


{
var bufferSensor = new BufferSensor(20, 4, "testName");
var shape = bufferSensor.GetObservationShape();
var dimProp = bufferSensor.GetDimensionProperties();
var shape = bufferSensor.GetObservationSpec().Shape;
var dimProp = bufferSensor.GetObservationSpec().DimensionProperties;
Assert.AreEqual(shape[0], 20);
Assert.AreEqual(shape[1], 4);
Assert.AreEqual(shape.Length, 2);

var obsWriter = new ObservationWriter();
var obs = bufferSensor.GetObservationProto(obsWriter);
Assert.AreEqual(shape, obs.Shape);
Assert.AreEqual(shape, InplaceArray<int>.FromList(obs.Shape));
Assert.AreEqual(obs.DimensionProperties.Count, 2);
Assert.AreEqual((int)dimProp[0], obs.DimensionProperties[0]);
Assert.AreEqual((int)dimProp[1], obs.DimensionProperties[1]);

3
com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs


Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape());
var sensor = cameraComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
var expectedShapeInplace = new InplaceArray<int>(height, width, grayscale ? 1 : 3);
Assert.AreEqual(expectedShapeInplace, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(CameraSensor), sensor.GetType());
}
}

13
com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs


public int Width { get; }
public int Height { get; }
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
public float[,] floatData;
public Float2DSensor(int width, int height, string name)

m_Name = name;
m_Shape = new[] { height, width, 1 };
m_ObservationSpec = ObservationSpec.Visual(height, width, 1);
floatData = new float[Height, Width];
}

Height = floatData.GetLength(0);
Width = floatData.GetLength(1);
m_Name = name;
m_Shape = new[] { Height, Width, 1 };
m_ObservationSpec = ObservationSpec.Visual(Height, Width, 1);
}
public string GetName()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
public byte[] GetCompressedObservation()

var output = new float[12];
var writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
sensor.Write(writer);
for (var i = 0; i < 9; i++)
{

2
com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs


{
ObservationWriter writer = new ObservationWriter();
var buffer = new[] { 0f, 0f, 0f };
var shape = new[] { 3 };
var shape = new InplaceArray<int>(3);
writer.SetTarget(buffer, shape, 0);
// Elementwise writes

20
com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs


var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

2
com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs


Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape());
var sensor = renderTexComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType());
}
}

27
com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs


using System.Collections.Generic;
using System.Text.RegularExpressions;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;

public class DummySensor : ISensor
{
string m_Name = "DummySensor";
int[] m_Shape;
ObservationSpec m_ObservationSpec;
m_Shape = new[] { dim1 };
m_ObservationSpec = ObservationSpec.Vector(dim1);
m_Shape = new[] { dim1, dim2, };
m_ObservationSpec = ObservationSpec.VariableLength(dim1, dim2);
m_Shape = new[] { dim1, dim2, dim3 };
m_ObservationSpec = ObservationSpec.Visual(dim1, dim2, dim3);
}
public string GetName()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
public byte[] GetCompressedObservation()

validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5) };
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList1);
}

validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 7) };
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList1);
}

var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(9) };
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList2);
// Add the sensors in the other order

LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList1);
}
}

30
com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs


ISensor wrapped = new VectorSensor(4);
ISensor sensor = new StackingSensor(wrapped, 4);
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
Assert.AreEqual(sensor.GetObservationShape(), new[] { 16 });
Assert.AreEqual(sensor.GetObservationSpec().Shape, new InplaceArray<int>(16));
}
[Test]

{
public SensorCompressionType CompressionType = SensorCompressionType.PNG;
public int[] Mapping;
public int[] Shape;
public ObservationSpec ObservationSpec;
public float[,,] CurrentObservation;
internal Dummy3DSensor()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return Shape;
return ObservationSpec;
for (var h = 0; h < Shape[0]; h++)
for (var h = 0; h < ObservationSpec.Shape[0]; h++)
for (var w = 0; w < Shape[1]; w++)
for (var w = 0; w < ObservationSpec.Shape[1]; w++)
for (var c = 0; c < Shape[2]; c++)
for (var c = 0; c < ObservationSpec.Shape[2]; c++)
return Shape[0] * Shape[1] * Shape[2];
return ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2];
var flattenedObservation = new float[Shape[0] * Shape[1] * Shape[2]];
writer.SetTarget(flattenedObservation, Shape, 0);
var flattenedObservation = new float[ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2]];
writer.SetTarget(flattenedObservation, ObservationSpec.Shape, 0);
Write(writer);
byte[] bytes = Array.ConvertAll(flattenedObservation, (z) => (byte)z);
return bytes;

// Test mapping with number of layers not being multiple of 3
var dummySensor = new Dummy3DSensor();
dummySensor.Shape = new[] { 2, 2, 4 };
dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });

paddedDummySensor.Shape = new[] { 2, 2, 4 };
paddedDummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 };
var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2);
Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });

public void Test3DStacking()
{
var wrapped = new Dummy3DSensor();
wrapped.Shape = new[] { 2, 1, 2 };
wrapped.ObservationSpec = ObservationSpec.Visual(2, 1, 2);
var sensor = new StackingSensor(wrapped, 2);
// Check the stacking is on the last dimension

public void TestStackedGetCompressedObservation()
{
var wrapped = new Dummy3DSensor();
wrapped.Shape = new[] { 1, 1, 3 };
wrapped.ObservationSpec = ObservationSpec.Visual(1, 1, 3);
var sensor = new StackingSensor(wrapped, 2);
wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } };

public void TestStackingSensorBuiltInSensorType()
{
var dummySensor = new Dummy3DSensor();
dummySensor.Shape = new[] { 2, 2, 4 };
dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetBuiltInSensorType(), BuiltInSensorType.Unknown);

26
docs/Learning-Environment-Design-Agents.md


impossible for the next decision. When the Agent is controlled by a neural
network, the Agent will be unable to perform the specified action. Note that
when the Agent is controlled by its Heuristic, the Agent will still be able to
decide to perform the masked action. In order to mask an action, override the
`Agent.WriteDiscreteActionMask()` virtual method, and call
`WriteMask()` on the provided `IDiscreteActionMask`:
decide to perform the masked action. In order to disallow an action, override
the `Agent.WriteDiscreteActionMask()` virtual method, and call
`SetActionEnabled()` on the provided `IDiscreteActionMask`:
actionMask.WriteMask(branch, actionIndices);
actionMask.SetActionEnabled(branch, actionIndex, isEnabled);
- `branch` is the index (starting at 0) of the branch on which you want to mask
the action
- `actionIndices` is a list of `int` corresponding to the indices of the actions
that the Agent **cannot** perform.
- `branch` is the index (starting at 0) of the branch on which you want to
allow or disallow the action
- `actionIndex` is the index of the action that you want to allow or disallow.
- `isEnabled` is a bool indicating whether the action should be allowed or now.
nothing"_ or _"change weapon"_ for his next decision (since action index 1 and 2
nothing"_ or _"change weapon"_ for their next decision (since action index 1 and 2
WriteMask(0, new int[2]{1,2});
actionMask.SetActionEnabled(0, 1, false);
actionMask.SetActionEnabled(0, 2, false);
- You can call `WriteMask` multiple times if you want to put masks on multiple
- You can call `SetActionEnabled` multiple times if you want to put masks on multiple
- At each step, the state of an action is reset and enabled by default.
- You cannot mask all the actions of a branch.
- You cannot mask actions in continuous control.

components (similar to the ISensor API). The `IActuator` interface and `Agent`
class both implement the `IActionReceiver` interface to allow for backward compatibility
with the current `Agent.OnActionReceived` and `Agent.CollectDiscreteActionMasks` APIs.
with the current `Agent.OnActionReceived`.
This means you will not have to change your code until you decide to use the `IActuator` API.
Like the `ISensor` interface, the `IActuator` interface is intended for advanced users.

50
docs/Migrating.md


# Migrating
## Migrating the package to version 2.0
- If you used any of the APIs that were deprecated before version 2.0, you need to use their replacement. These deprecated APIs have been removed. See the migration steps bellow for specific API replacements.
### IDiscreteActionMask changes
- The interface for disabling specific discrete actions has changed. `IDiscreteActionMask.WriteMask()` was removed,
and replaced with `SetActionEnabled()`. Instead of returning an IEnumerable with indices to disable, you can
now call `SetActionEnabled` for each index to disable (or enable). As an example, if you overrode
`Agent.WriteDiscreteActionMask()` with something that looked like:
```csharp
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
var branch = 2;
var actionsToDisable = new[] {1, 3};
actionMask.WriteMask(branch, actionsToDisable);
}
```
the equivalent code would now be
```csharp
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
var branch = 2;
actionMask.SetActionEnabled(branch, 1, false);
actionMask.SetActionEnabled(branch, 3, false);
}
```
- The `IActuator` interface now implements `IHeuristicProvider`. Please add the corresponding `Heuristic(in ActionBuffers)`
method to your custom Actuator classes.
- The `ISensor.GetObservationShape()` method was removed, and `GetObservationSpec()` was added. You can use
`ObservationSpec.Vector()` or `ObservationSpec.Visual()` to generate `ObservationSpec`s that are equivalent to
the previous shape. For example, if your old ISensor looked like:
```csharp
public override int[] GetObservationShape()
{
return new[] { m_Height, m_Width, m_NumChannels };
}
```
the equivalent code would now be
```csharp
public override ObservationSpec GetObservationSpec()
{
return ObservationSpec.Visual(m_Height, m_Width, m_NumChannels);
}
```
## Migrating to Release 13
### Implementing IHeuristic in your IActuator implementations
- If you have any custom actuators, you can now implement the `IHeuristicProvider` interface to have your actuator

1
docs/Readme.md


- [Designing a Learning Environment](Learning-Environment-Design.md)
- [Designing Agents](Learning-Environment-Design-Agents.md)
- [Using an Executable Environment](Learning-Environment-Executable.md)
- [ML-Agents Package Settings](Package-Settings.md)
## Training & Inference

2
ml-agents/mlagents/trainers/torch/distributions.py


).unsqueeze(-1)
def exported_model_output(self):
return self.all_log_prob()
return self.sample()
class GaussianDistribution(nn.Module):

61
ml-agents/mlagents/trainers/torch/model_serialization.py


from typing import Tuple
import threading
from mlagents.torch_utils import torch

observation_specs = self.policy.behavior_spec.observation_specs
batch_dim = [1]
seq_len_dim = [1]
vec_obs_size = 0
for obs_spec in observation_specs:
if len(obs_spec.shape) == 1:
vec_obs_size += obs_spec.shape[0]
num_vis_obs = sum(
1 for obs_spec in observation_specs if len(obs_spec.shape) == 3
)
dummy_vec_obs = [torch.zeros(batch_dim + [vec_obs_size])]
# create input shape of NCHW
# (It's NHWC in observation_specs.shape)
dummy_vis_obs = [
num_obs = len(observation_specs)
dummy_obs = [
batch_dim + [obs_spec.shape[2], obs_spec.shape[0], obs_spec.shape[1]]
batch_dim + list(ModelSerializer._get_onnx_shape(obs_spec.shape))
if len(obs_spec.shape) == 3
]
dummy_var_len_obs = [
torch.zeros(batch_dim + [obs_spec.shape[0], obs_spec.shape[1]])
for obs_spec in observation_specs
if len(obs_spec.shape) == 2
]
dummy_masks = torch.ones(

batch_dim + seq_len_dim + [self.policy.export_memory_size]
)
self.dummy_input = (
dummy_vec_obs,
dummy_vis_obs,
dummy_var_len_obs,
dummy_masks,
dummy_memories,
)
self.dummy_input = (dummy_obs, dummy_masks, dummy_memories)
self.input_names = [TensorNames.vector_observation_placeholder]
for i in range(num_vis_obs):
self.input_names.append(TensorNames.get_visual_observation_name(i))
for i, obs_spec in enumerate(observation_specs):
if len(obs_spec.shape) == 2:
self.input_names.append(TensorNames.get_observation_name(i))
self.input_names = [TensorNames.get_observation_name(i) for i in range(num_obs)]
self.input_names += [
TensorNames.action_mask_placeholder,
TensorNames.recurrent_in_placeholder,

TensorNames.discrete_action_output_shape,
]
self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}})
if (
self.policy.behavior_spec.action_spec.continuous_size == 0
or self.policy.behavior_spec.action_spec.discrete_size == 0
):
self.output_names += [
TensorNames.action_output_deprecated,
TensorNames.is_continuous_control_deprecated,
TensorNames.action_output_shape_deprecated,
]
self.dynamic_axes.update(
{TensorNames.action_output_deprecated: {0: "batch"}}
)
@staticmethod
def _get_onnx_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
"""
Converts the shape of an observation to be compatible with the NCHW format
of ONNX
"""
if len(shape) == 3:
return shape[2], shape[0], shape[1]
return shape
def export_policy_model(self, output_filepath: str) -> None:
"""

44
ml-agents/mlagents/trainers/torch/networks.py


@abc.abstractmethod
def forward(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
var_len_inputs: List[torch.Tensor],
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[Union[int, torch.Tensor], ...]:

class SimpleActor(nn.Module, Actor):
MODEL_EXPORT_VERSION = 3
def __init__(
self,
observation_specs: List[ObservationSpec],

super().__init__()
self.action_spec = action_spec
self.version_number = torch.nn.Parameter(
torch.Tensor([2.0]), requires_grad=False
torch.Tensor([self.MODEL_EXPORT_VERSION]), requires_grad=False
)
self.is_continuous_int_deprecated = torch.nn.Parameter(
torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False

)
# TODO: export list of branch sizes instead of sum
torch.Tensor([sum(self.action_spec.discrete_branches)]), requires_grad=False
torch.Tensor([self.action_spec.discrete_branches]), requires_grad=False
)
self.act_size_vector_deprecated = torch.nn.Parameter(
torch.Tensor(

def forward(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
var_len_inputs: List[torch.Tensor],
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[Union[int, torch.Tensor], ...]:

At this moment, torch.onnx.export() doesn't accept None as tensor to be exported,
so the size of return tuple varies with action spec.
"""
# This code will convert the vec and vis obs into a list of inputs for the network
concatenated_vec_obs = vec_inputs[0]
inputs = []
start = 0
end = 0
vis_index = 0
var_len_index = 0
for i, enc in enumerate(self.network_body.observation_encoder.processors):
if isinstance(enc, VectorInput):
# This is a vec_obs
vec_size = self.network_body.observation_encoder.embedding_sizes[i]
end = start + vec_size
inputs.append(concatenated_vec_obs[:, start:end])
start = end
elif isinstance(enc, EntityEmbedding):
inputs.append(var_len_inputs[var_len_index])
var_len_index += 1
else: # visual input
inputs.append(vis_inputs[vis_index])
vis_index += 1
# End of code to convert the vec and vis obs into a list of inputs for the network
encoding, memories_out = self.network_body(
inputs, memories=memories, sequence_length=1
)

export_out += [cont_action_out, self.continuous_act_size_vector]
if self.action_spec.discrete_size > 0:
export_out += [disc_action_out, self.discrete_act_size_vector]
# Only export deprecated nodes with non-hybrid action spec
if self.action_spec.continuous_size == 0 or self.action_spec.discrete_size == 0:
export_out += [
action_out_deprecated,
self.is_continuous_int_deprecated,
self.act_size_vector_deprecated,
]
if self.network_body.memory_size > 0:
export_out += [memories_out]
return tuple(export_out)

2
com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta


fileFormatVersion: 2
guid: 297e9ec12d6de45adbcf6dea1a9de019
guid: 8e1cdc27e533749fabc04b3cdeb93501
MonoImporter:
externalObjects: {}
serializedVersion: 2

41
com.unity.ml-agents/Tests/Editor/Inference/DiscreteActionOutputApplierTest.cs


namespace Unity.MLAgents.Tests
{
public class DiscreteActionOutputApplierTest
{
[Test]

var applier = new DiscreteActionOutputApplier(actionSpec, 2020, null);
var agentIds = new List<int> { 42, 1337 };
var actionBuffers = new Dictionary<int, ActionBuffers>();
actionBuffers[42] = new ActionBuffers(actionSpec);
actionBuffers[1337] = new ActionBuffers(actionSpec);
var actionTensor = new TensorProxy
{
data = new Tensor(
2,
2,
new[]
{
2.0f, // Agent 0, branch 0
1.0f, // Agent 0, branch 1
0.0f, // Agent 1, branch 0
0.0f // Agent 1, branch 1
}),
shape = new long[] { 2, 2 },
valueType = TensorProxy.TensorType.FloatingPoint
};
applier.Apply(actionTensor, agentIds, actionBuffers);
Assert.AreEqual(2, actionBuffers[42].DiscreteActions[0]);
Assert.AreEqual(1, actionBuffers[42].DiscreteActions[1]);
Assert.AreEqual(0, actionBuffers[1337].DiscreteActions[0]);
Assert.AreEqual(0, actionBuffers[1337].DiscreteActions[1]);
}
}
public class LegacyDiscreteActionOutputApplierTest
{
[Test]
public void TestDiscreteApply()
{
var actionSpec = ActionSpec.MakeDiscrete(3, 2);
const float smallLogProb = -1000.0f;
const float largeLogProb = -1.0f;

valueType = TensorProxy.TensorType.FloatingPoint
};
var applier = new DiscreteActionOutputApplier(actionSpec, 2020, null);
var applier = new LegacyDiscreteActionOutputApplier(actionSpec, 2020, null);
var agentIds = new List<int> { 42, 1337 };
var actionBuffers = new Dictionary<int, ActionBuffers>();
actionBuffers[42] = new ActionBuffers(actionSpec);

77
com.unity.ml-agents/Tests/Editor/Inference/EditModeTestInternalBrainTensorApplier.cs


}
[Test]
public void ApplyDiscreteActionOutput()
public void ApplyDiscreteActionOutputLegacy()
{
var actionSpec = ActionSpec.MakeDiscrete(2, 3);
var inputTensor = new TensorProxy()

new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
};
var alloc = new TensorCachingAllocator();
var applier = new LegacyDiscreteActionOutputApplier(actionSpec, 0, alloc);
var agentIds = new List<int>() { 0, 1 };
// Dictionary from AgentId to Action
var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } };
applier.Apply(inputTensor, agentIds, actionDict);
Assert.AreEqual(actionDict[0].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[0].DiscreteActions[1], 1);
Assert.AreEqual(actionDict[1].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[1].DiscreteActions[1], 2);
alloc.Dispose();
}
[Test]
public void ApplyDiscreteActionOutput()
{
var actionSpec = ActionSpec.MakeDiscrete(2, 3);
var inputTensor = new TensorProxy()
{
shape = new long[] { 2, 2 },
data = new Tensor(
2,
2,
new[] { 1f, 1f, 1f, 2f }),
};
var alloc = new TensorCachingAllocator();
var applier = new DiscreteActionOutputApplier(actionSpec, 0, alloc);
var agentIds = new List<int>() { 0, 1 };

}
[Test]
public void ApplyHybridActionOutput()
public void ApplyHybridActionOutputLegacy()
{
var actionSpec = new ActionSpec(3, new[] { 2, 3 });
var continuousInputTensor = new TensorProxy()

2,
5,
new[] { 0.5f, 22.5f, 0.1f, 5f, 1f, 4f, 5f, 6f, 7f, 8f })
};
var continuousApplier = new ContinuousActionOutputApplier(actionSpec);
var alloc = new TensorCachingAllocator();
var discreteApplier = new LegacyDiscreteActionOutputApplier(actionSpec, 0, alloc);
var agentIds = new List<int>() { 0, 1 };
// Dictionary from AgentId to Action
var actionDict = new Dictionary<int, ActionBuffers>() { { 0, ActionBuffers.Empty }, { 1, ActionBuffers.Empty } };
continuousApplier.Apply(continuousInputTensor, agentIds, actionDict);
discreteApplier.Apply(discreteInputTensor, agentIds, actionDict);
Assert.AreEqual(actionDict[0].ContinuousActions[0], 1);
Assert.AreEqual(actionDict[0].ContinuousActions[1], 2);
Assert.AreEqual(actionDict[0].ContinuousActions[2], 3);
Assert.AreEqual(actionDict[0].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[0].DiscreteActions[1], 1);
Assert.AreEqual(actionDict[1].ContinuousActions[0], 4);
Assert.AreEqual(actionDict[1].ContinuousActions[1], 5);
Assert.AreEqual(actionDict[1].ContinuousActions[2], 6);
Assert.AreEqual(actionDict[1].DiscreteActions[0], 1);
Assert.AreEqual(actionDict[1].DiscreteActions[1], 2);
alloc.Dispose();
}
[Test]
public void ApplyHybridActionOutput()
{
var actionSpec = new ActionSpec(3, new[] { 2, 3 });
var continuousInputTensor = new TensorProxy()
{
shape = new long[] { 2, 3 },
data = new Tensor(2, 3, new float[] { 1, 2, 3, 4, 5, 6 })
};
var discreteInputTensor = new TensorProxy()
{
shape = new long[] { 2, 2 },
data = new Tensor(
2,
2,
new[] { 1f, 1f, 1f, 2f }),
};
var continuousApplier = new ContinuousActionOutputApplier(actionSpec);
var alloc = new TensorCachingAllocator();

12
com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs


[TestFixture]
public class ModelRunnerTest
{
const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx";
const string k_discreteONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx";
const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx";
const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn";
const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn";
const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_v1_0.onnx";
const string k_discreteONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_v1_0.onnx";
const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx";
const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn";
const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn";
NNModel continuousONNXModel;
NNModel discreteONNXModel;
NNModel hybridONNXModel;

Test3DSensorComponent sensor_20_22_3;
ActionSpec GetContinuous2vis8vec2actionActionSpec()
{

7
DevProject/Assets/ML-Agents/Scripts/Tests/Editor/Editor.asmdef.meta


fileFormatVersion: 2
guid: 5b142e67c2d6b4b1e928e4d54f01a596
AssemblyDefinitionImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

8
DevProject/Assets/ML-Agents/Scripts/Tests/Editor/MLAgentsSettings.meta


fileFormatVersion: 2
guid: 1fc80f44976bc4177a9afaa0a38abab3
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

75
com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs


using System.Linq;
using UnityEngine;
using UnityEditor;
using UnityEditor.Build;
using UnityEditor.Build.Reporting;
namespace Unity.MLAgents.Editor
{
internal class MLAgentsSettingsBuildProvider : IPreprocessBuildWithReport, IPostprocessBuildWithReport
{
private MLAgentsSettings m_SettingsAddedToPreloadedAssets;
public int callbackOrder => 0;
public void OnPreprocessBuild(BuildReport report)
{
var wasDirty = IsPlayerSettingsDirty();
m_SettingsAddedToPreloadedAssets = null;
var preloadedAssets = PlayerSettings.GetPreloadedAssets().ToList();
if (!preloadedAssets.Contains(MLAgentsSettingsManager.Settings))
{
m_SettingsAddedToPreloadedAssets = MLAgentsSettingsManager.Settings;
preloadedAssets.Add(m_SettingsAddedToPreloadedAssets);
PlayerSettings.SetPreloadedAssets(preloadedAssets.ToArray());
}
if (!wasDirty)
ClearPlayerSettingsDirtyFlag();
}
public void OnPostprocessBuild(BuildReport report)
{
if (m_SettingsAddedToPreloadedAssets == null)
return;
var wasDirty = IsPlayerSettingsDirty();
var preloadedAssets = PlayerSettings.GetPreloadedAssets().ToList();
if (preloadedAssets.Contains(m_SettingsAddedToPreloadedAssets))
{
preloadedAssets.Remove(m_SettingsAddedToPreloadedAssets);
PlayerSettings.SetPreloadedAssets(preloadedAssets.ToArray());
}
m_SettingsAddedToPreloadedAssets = null;
if (!wasDirty)
ClearPlayerSettingsDirtyFlag();
}
private static bool IsPlayerSettingsDirty()
{
#if UNITY_2019_OR_NEWER
var settings = Resources.FindObjectsOfTypeAll<PlayerSettings>();
if (settings != null && settings.Length > 0)
return EditorUtility.IsDirty(settings[0]);
return false;
#else
return false;
#endif
}
private static void ClearPlayerSettingsDirtyFlag()
{
#if UNITY_2019_OR_NEWER
var settings = Resources.FindObjectsOfTypeAll<PlayerSettings>();
if (settings != null && settings.Length > 0)
EditorUtility.ClearDirty(settings[0]);
#endif
}
}
}

11
com.unity.ml-agents/Editor/MLAgentsSettingsBuildProvider.cs.meta


fileFormatVersion: 2
guid: bd59ff34305fa4259a2735e08afdb424
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

198
com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs


using System;
using System.Linq;
using System.IO;
using System.Runtime.CompilerServices;
using UnityEngine;
using UnityEditor;
#if UNITY_2019_4_OR_NEWER
using UnityEngine.UIElements;
#else
using UnityEngine.Experimental.UIElements;
#endif
[assembly: InternalsVisibleTo("Unity.ML-Agents.DevTests.Editor")]
namespace Unity.MLAgents.Editor
{
internal class MLAgentsSettingsProvider : SettingsProvider, IDisposable
{
const string k_SettingsPath = "Project/ML-Agents";
private static MLAgentsSettingsProvider s_Instance;
private string[] m_AvailableSettingsAssets;
private int m_CurrentSelectedSettingsAsset;
private SerializedObject m_SettingsObject;
[SerializeField]
private MLAgentsSettings m_Settings;
private MLAgentsSettingsProvider(string path, SettingsScope scope = SettingsScope.Project)
: base(path, scope)
{
s_Instance = this;
}
[SettingsProvider]
public static SettingsProvider CreateMLAgentsSettingsProvider()
{
return new MLAgentsSettingsProvider(k_SettingsPath, SettingsScope.Project);
}
public override void OnActivate(string searchContext, VisualElement rootElement)
{
base.OnActivate(searchContext, rootElement);
MLAgentsSettingsManager.OnSettingsChange += Reinitialize;
}
public override void OnDeactivate()
{
base.OnDeactivate();
MLAgentsSettingsManager.OnSettingsChange -= Reinitialize;
}
public void Dispose()
{
m_SettingsObject?.Dispose();
}
public override void OnTitleBarGUI()
{
if (EditorGUILayout.DropdownButton(EditorGUIUtility.IconContent("_Popup"), FocusType.Passive, EditorStyles.label))
{
var menu = new GenericMenu();
for (var i = 0; i < m_AvailableSettingsAssets.Length; i++)
{
menu.AddItem(ExtractDisplayName(m_AvailableSettingsAssets[i]), m_CurrentSelectedSettingsAsset == i, (path) =>
{
MLAgentsSettingsManager.Settings = AssetDatabase.LoadAssetAtPath<MLAgentsSettings>((string)path);
}, m_AvailableSettingsAssets[i]);
}
menu.AddSeparator("");
menu.AddItem(new GUIContent("New Settings Asset…"), false, CreateNewSettingsAsset);
menu.ShowAsContext();
Event.current.Use();
}
}
private GUIContent ExtractDisplayName(string name)
{
if (name.StartsWith("Assets/"))
name = name.Substring("Assets/".Length);
if (name.EndsWith(".asset"))
name = name.Substring(0, name.Length - ".asset".Length);
if (name.EndsWith(".mlagents.settings"))
name = name.Substring(0, name.Length - ".mlagents.settings".Length);
// Ugly hack: GenericMenu interprets "/" as a submenu path. But luckily, "/" is not the only slash we have in Unicode.
return new GUIContent(name.Replace("/", "\u29f8"));
}
private void CreateNewSettingsAsset()
{
// Asset database always use forward slashes. Use forward slashes for all the paths.
var projectName = PlayerSettings.productName;
var path = EditorUtility.SaveFilePanel("Create ML-Agents Settings File", "Assets",
projectName + ".mlagents.settings", "asset");
if (string.IsNullOrEmpty(path))
{
return;
}
path = path.Replace("\\", "/"); // Make sure we only get '/' separators.
var assetPath = Application.dataPath + "/";
if (!path.StartsWith(assetPath, StringComparison.CurrentCultureIgnoreCase))
{
Debug.LogError(string.Format(
"Settings must be stored in Assets folder of the project (got: '{0}')", path));
return;
}
var extension = Path.GetExtension(path);
if (string.Compare(extension, ".asset", StringComparison.InvariantCultureIgnoreCase) != 0)
{
path += ".asset";
}
var relativePath = "Assets/" + path.Substring(assetPath.Length);
CreateNewSettingsAsset(relativePath);
}
private static void CreateNewSettingsAsset(string relativePath)
{
var settings = ScriptableObject.CreateInstance<MLAgentsSettings>();
AssetDatabase.CreateAsset(settings, relativePath);
EditorGUIUtility.PingObject(settings);
// Install the settings. This will lead to an MLAgentsManager.OnSettingsChange event
// which in turn will cause this Provider to reinitialize
MLAgentsSettingsManager.Settings = settings;
}
public override void OnGUI(string searchContext)
{
if (m_Settings == null)
{
InitializeWithCurrentSettings();
}
if (m_AvailableSettingsAssets.Length == 0)
{
EditorGUILayout.HelpBox(
"Click the button below to create a settings asset you can edit.",
MessageType.Info);
if (GUILayout.Button("Create settings asset", GUILayout.Height(30)))
CreateNewSettingsAsset();
GUILayout.Space(20);
}
using (new EditorGUI.DisabledScope(m_AvailableSettingsAssets.Length == 0))
{
EditorGUI.BeginChangeCheck();
EditorGUILayout.LabelField("Trainer Settings", EditorStyles.boldLabel);
EditorGUI.indentLevel++;
EditorGUILayout.PropertyField(m_SettingsObject.FindProperty("m_ConnectTrainer"), new GUIContent("Connect to Trainer"));
EditorGUILayout.PropertyField(m_SettingsObject.FindProperty("m_EditorPort"), new GUIContent("Editor Training Port"));
EditorGUI.indentLevel--;
if (EditorGUI.EndChangeCheck())
m_SettingsObject.ApplyModifiedProperties();
}
}
internal void InitializeWithCurrentSettings()
{
m_AvailableSettingsAssets = FindSettingsInProject();
m_Settings = MLAgentsSettingsManager.Settings;
var currentSettingsPath = AssetDatabase.GetAssetPath(m_Settings);
if (string.IsNullOrEmpty(currentSettingsPath))
{
if (m_AvailableSettingsAssets.Length > 0)
{
m_CurrentSelectedSettingsAsset = 0;
m_Settings = AssetDatabase.LoadAssetAtPath<MLAgentsSettings>(m_AvailableSettingsAssets[0]);
MLAgentsSettingsManager.Settings = m_Settings;
}
}
else
{
var settingsList = m_AvailableSettingsAssets.ToList();
m_CurrentSelectedSettingsAsset = settingsList.IndexOf(currentSettingsPath);
EditorBuildSettings.AddConfigObject(MLAgentsSettingsManager.EditorBuildSettingsConfigKey, m_Settings, true);
}
m_SettingsObject = new SerializedObject(m_Settings);
}
private static string[] FindSettingsInProject()
{
var guids = AssetDatabase.FindAssets("t:MLAgentsSettings");
return guids.Select(guid => AssetDatabase.GUIDToAssetPath(guid)).ToArray();
}
private void Reinitialize()
{
if (m_Settings != null && MLAgentsSettingsManager.Settings != m_Settings)
{
InitializeWithCurrentSettings();
}
Repaint();
}
}
}

11
com.unity.ml-agents/Editor/MLAgentsSettingsProvider.cs.meta


fileFormatVersion: 2
guid: 162489862d7f64a40990a0c06bb73bd0
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

241
com.unity.ml-agents/Runtime/InplaceArray.cs


using System;
using System.Collections.Generic;
namespace Unity.MLAgents
{
/// <summary>
/// An array-like object that stores up to four elements.
/// This is a value type that does not allocate any additional memory.
/// </summary>
/// <remarks>
/// This does not implement any interfaces such as IList, in order to avoid any accidental boxing allocations.
/// </remarks>
/// <typeparam name="T"></typeparam>
public struct InplaceArray<T> : IEquatable<InplaceArray<T>> where T : struct
{
private const int k_MaxLength = 4;
private readonly int m_Length;
private T m_Elem0;
private T m_Elem1;
private T m_Elem2;
private T m_Elem3;
/// <summary>
/// Create a length-1 array.
/// </summary>
/// <param name="elem0"></param>
public InplaceArray(T elem0)
{
m_Length = 1;
m_Elem0 = elem0;
m_Elem1 = new T { };
m_Elem2 = new T { };
m_Elem3 = new T { };
}
/// <summary>
/// Create a length-2 array.
/// </summary>
/// <param name="elem0"></param>
/// <param name="elem1"></param>
public InplaceArray(T elem0, T elem1)
{
m_Length = 2;
m_Elem0 = elem0;
m_Elem1 = elem1;
m_Elem2 = new T { };
m_Elem3 = new T { };
}
/// <summary>
/// Create a length-3 array.
/// </summary>
/// <param name="elem0"></param>
/// <param name="elem1"></param>
/// <param name="elem2"></param>
public InplaceArray(T elem0, T elem1, T elem2)
{
m_Length = 3;
m_Elem0 = elem0;
m_Elem1 = elem1;
m_Elem2 = elem2;
m_Elem3 = new T { };
}
/// <summary>
/// Create a length-3 array.
/// </summary>
/// <param name="elem0"></param>
/// <param name="elem1"></param>
/// <param name="elem2"></param>
/// <param name="elem3"></param>
public InplaceArray(T elem0, T elem1, T elem2, T elem3)
{
m_Length = 4;
m_Elem0 = elem0;
m_Elem1 = elem1;
m_Elem2 = elem2;
m_Elem3 = elem3;
}
/// <summary>
/// Construct an InplaceArray from an IList (e.g. Array or List).
/// The source must be non-empty and have at most 4 elements.
/// </summary>
/// <param name="elems"></param>
/// <returns></returns>
/// <exception cref="ArgumentOutOfRangeException"></exception>
public static InplaceArray<T> FromList(IList<T> elems)
{
switch (elems.Count)
{
case 1:
return new InplaceArray<T>(elems[0]);
case 2:
return new InplaceArray<T>(elems[0], elems[1]);
case 3:
return new InplaceArray<T>(elems[0], elems[1], elems[2]);
case 4:
return new InplaceArray<T>(elems[0], elems[1], elems[2], elems[3]);
default:
throw new ArgumentOutOfRangeException();
}
}
/// <summary>
/// Per-element access.
/// </summary>
/// <param name="index"></param>
/// <exception cref="IndexOutOfRangeException"></exception>
public T this[int index]
{
get
{
if (index >= Length)
{
throw new IndexOutOfRangeException();
}
switch (index)
{
case 0:
return m_Elem0;
case 1:
return m_Elem1;
case 2:
return m_Elem2;
case 3:
return m_Elem3;
default:
throw new IndexOutOfRangeException();
}
}
set
{
if (index >= Length)
{
throw new IndexOutOfRangeException();
}
switch (index)
{
case 0:
m_Elem0 = value;
break;
case 1:
m_Elem1 = value;
break;
case 2:
m_Elem2 = value;
break;
case 3:
m_Elem3 = value;
break;
default:
throw new IndexOutOfRangeException();
}
}
}
/// <summary>
/// The length of the array.
/// </summary>
public int Length
{
get => m_Length;
}
/// <summary>
/// Returns a string representation of the array's elements.
/// </summary>
/// <returns></returns>
/// <exception cref="IndexOutOfRangeException"></exception>
public override string ToString()
{
switch (m_Length)
{
case 1:
return $"[{m_Elem0}]";
case 2:
return $"[{m_Elem0}, {m_Elem1}]";
case 3:
return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}]";
case 4:
return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}, {m_Elem3}]";
default:
throw new IndexOutOfRangeException();
}
}
/// <summary>
/// Check that the arrays have the same length and have all equal values.
/// </summary>
/// <param name="lhs"></param>
/// <param name="rhs"></param>
/// <returns>Whether the arrays are equivalent.</returns>
public static bool operator ==(InplaceArray<T> lhs, InplaceArray<T> rhs)
{
return lhs.Equals(rhs);
}
/// <summary>
/// Check that the arrays are not equivalent.
/// </summary>
/// <param name="lhs"></param>
/// <param name="rhs"></param>
/// <returns>Whether the arrays are not equivalent</returns>
public static bool operator !=(InplaceArray<T> lhs, InplaceArray<T> rhs) => !lhs.Equals(rhs);
/// <summary>
/// Check that the arrays are equivalent.
/// </summary>
/// <param name="other"></param>
/// <returns>Whether the arrays are not equivalent</returns>
public override bool Equals(object other) => other is InplaceArray<T> other1 && this.Equals(other1);
/// <summary>
/// Check that the arrays are equivalent.
/// </summary>
/// <param name="other"></param>
/// <returns>Whether the arrays are not equivalent</returns>
public bool Equals(InplaceArray<T> other)
{
// See https://montemagno.com/optimizing-c-struct-equality-with-iequatable/
var thisTuple = (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length);
var otherTuple = (other.m_Elem0, other.m_Elem1, other.m_Elem2, other.m_Elem3, other.Length);
return thisTuple.Equals(otherTuple);
}
/// <summary>
/// Get a hashcode for the array.
/// </summary>
/// <returns></returns>
public override int GetHashCode()
{
return (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length).GetHashCode();
}
}
}

3
com.unity.ml-agents/Runtime/InplaceArray.cs.meta


fileFormatVersion: 2
guid: c1a80abee18a41c8aee89aeb33f5985d
timeCreated: 1615506199

41
com.unity.ml-agents/Runtime/MLAgentsSettings.cs


using UnityEngine;
using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("Unity.ML-Agents.DevTests.Editor")]
namespace Unity.MLAgents
{
internal class MLAgentsSettings : ScriptableObject
{
[SerializeField]
private bool m_ConnectTrainer = true;
[SerializeField]
private int m_EditorPort = 5004;
public bool ConnectTrainer
{
get { return m_ConnectTrainer; }
set
{
m_ConnectTrainer = value;
OnChange();
}
}
public int EditorPort
{
get { return m_EditorPort; }
set
{
m_EditorPort = value;
OnChange();
}
}
internal void OnChange()
{
if (MLAgentsSettingsManager.Settings == this)
MLAgentsSettingsManager.ApplySettings();
}
}
}

11
com.unity.ml-agents/Runtime/MLAgentsSettings.cs.meta


fileFormatVersion: 2
guid: 71515ce028aaa4b4cb6bee13e96ef6f3
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

91
com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs


using System;
using System.Linq;
using UnityEngine;
#if UNITY_EDITOR
using UnityEditor;
#endif
namespace Unity.MLAgents
{
#if UNITY_EDITOR
[InitializeOnLoad]
#endif
internal static class MLAgentsSettingsManager
{
internal static event Action OnSettingsChange;
internal const string EditorBuildSettingsConfigKey = "com.unity.ml-agents.settings";
private static MLAgentsSettings s_Settings;
// setter will trigger callback for refreshing editor UI if using editor
public static MLAgentsSettings Settings
{
get
{
if (s_Settings == null)
{
Initialize();
}
return s_Settings;
}
set
{
Debug.Assert(value != null);
#if UNITY_EDITOR
if (!string.IsNullOrEmpty(AssetDatabase.GetAssetPath(value)))
{
EditorBuildSettings.AddConfigObject(EditorBuildSettingsConfigKey, value, true);
}
#endif
s_Settings = value;
ApplySettings();
}
}
static MLAgentsSettingsManager()
{
Initialize();
}
static void Initialize()
{
#if UNITY_EDITOR
InitializeInEditor();
#else
InitializeInPlayer();
#endif
}
#if UNITY_EDITOR
internal static void InitializeInEditor()
{
var settings = ScriptableObject.CreateInstance<MLAgentsSettings>();
if (EditorBuildSettings.TryGetConfigObject(EditorBuildSettingsConfigKey,
out MLAgentsSettings settingsAsset))
{
if (settingsAsset != null)
{
settings = settingsAsset;
}
}
Settings = settings;
}
#else
internal static void InitializeInPlayer()
{
Settings = Resources.FindObjectsOfTypeAll<MLAgentsSettings>().FirstOrDefault() ?? ScriptableObject.CreateInstance<MLAgentsSettings>();
}
#endif
internal static void ApplySettings()
{
OnSettingsChange?.Invoke();
}
internal static void Destroy()
{
s_Settings = null;
OnSettingsChange = null;
}
}
}

11
com.unity.ml-agents/Runtime/MLAgentsSettingsManager.cs.meta


fileFormatVersion: 2
guid: be40451993af54e3c84c7113140fdf2c
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

140
com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs


using Unity.Barracuda;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// A description of the observations that an ISensor produces.
/// This includes the size of the observation, the properties of each dimension, and how the observation
/// should be used for training.
/// </summary>
public struct ObservationSpec
{
internal readonly InplaceArray<int> m_Shape;
/// <summary>
/// The size of the observations that will be generated.
/// For example, a sensor that observes the velocity of a rigid body (in 3D) would use [3].
/// A sensor that returns an RGB image would use [Height, Width, 3].
/// </summary>
public InplaceArray<int> Shape
{
get => m_Shape;
}
internal readonly InplaceArray<DimensionProperty> m_DimensionProperties;
/// <summary>
/// The properties of each dimensions of the observation.
/// The length of the array must be equal to the rank of the observation tensor.
/// </summary>
/// <remarks>
/// It is generally recommended to use default values provided by helper functions,
/// as not all combinations of DimensionProperty may be supported by the trainer.
/// </remarks>
public InplaceArray<DimensionProperty> DimensionProperties
{
get => m_DimensionProperties;
}
internal ObservationType m_ObservationType;
/// <summary>
/// The type of the observation, e.g. whether they are generic or
/// help determine the goal for the Agent.
/// </summary>
public ObservationType ObservationType
{
get => m_ObservationType;
}
/// <summary>
/// The number of dimensions of the observation.
/// </summary>
public int Rank
{
get { return Shape.Length; }
}
/// <summary>
/// Construct an ObservationSpec for 1-D observations of the requested length.
/// </summary>
/// <param name="length"></param>
/// <param name="obsType"></param>
/// <returns></returns>
public static ObservationSpec Vector(int length, ObservationType obsType = ObservationType.Default)
{
return new ObservationSpec(
new InplaceArray<int>(length),
new InplaceArray<DimensionProperty>(DimensionProperty.None),
obsType
);
}
/// <summary>
/// Construct an ObservationSpec for variable-length observations.
/// </summary>
/// <param name="obsSize"></param>
/// <param name="maxNumObs"></param>
/// <returns></returns>
public static ObservationSpec VariableLength(int obsSize, int maxNumObs)
{
var dimProps = new InplaceArray<DimensionProperty>(
DimensionProperty.VariableSize,
DimensionProperty.None
);
return new ObservationSpec(
new InplaceArray<int>(obsSize, maxNumObs),
dimProps
);
}
/// <summary>
/// Construct an ObservationSpec for visual-like observations, e.g. observations
/// with a height, width, and possible multiple channels.
/// </summary>
/// <param name="height"></param>
/// <param name="width"></param>
/// <param name="channels"></param>
/// <param name="obsType"></param>
/// <returns></returns>
public static ObservationSpec Visual(int height, int width, int channels, ObservationType obsType = ObservationType.Default)
{
var dimProps = new InplaceArray<DimensionProperty>(
DimensionProperty.TranslationalEquivariance,
DimensionProperty.TranslationalEquivariance,
DimensionProperty.None
);
return new ObservationSpec(
new InplaceArray<int>(height, width, channels),
dimProps,
obsType
);
}
/// <summary>
/// Create a general ObservationSpec from the shape, dimension properties, and observation type.
/// </summary>
/// <remarks>
/// Note that not all combinations of DimensionProperty may be supported by the trainer.
/// shape and dimensionProperties must have the same size.
/// </remarks>
/// <param name="shape"></param>
/// <param name="dimensionProperties"></param>
/// <param name="observationType"></param>
/// <exception cref="UnityAgentsException"></exception>
public ObservationSpec(
InplaceArray<int> shape,
InplaceArray<DimensionProperty> dimensionProperties,
ObservationType observationType = ObservationType.Default
)
{
if (shape.Length != dimensionProperties.Length)
{
throw new UnityAgentsException("shape and dimensionProperties must have the same length.");
}
m_Shape = shape;
m_DimensionProperties = dimensionProperties;
m_ObservationType = observationType;
}
}
}

3
com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta


fileFormatVersion: 2
guid: cc1734d60fd5485ead94247cb206aa35
timeCreated: 1615412644

8
com.unity.ml-agents/Tests/Editor/Inference.meta


fileFormatVersion: 2
guid: 7b8fc3bc69d3a4cd9a66ad334f944fb2
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

部分文件因为文件数量过多而无法显示

正在加载...
取消
保存