
Merge branch 'develop-base-teammanager' into develop-agentprocessor-teammanager

Ervin Teng, 3 years ago
Current commit
b6f88d6d
Showing 191 changed files with 1,509 additions and 687 deletions
1. .yamato/com.unity.ml-agents-performance.yml (2 changes)
2. .yamato/gym-interface-test.yml (1 change)
3. Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs (16 changes)
4. Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (3 changes)
5. Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab (25 changes)
6. Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab (25 changes)
7. Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab (25 changes)
8. Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity (10 changes)
9. Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs (166 changes)
10. Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs (13 changes)
11. Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs (2 changes)
12. Project/ProjectSettings/UnityConnectSettings.asset (2 changes)
13. README.md (51 changes)
14. com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs (74 changes)
15. com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs (10 changes)
16. com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs (8 changes)
17. com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs (9 changes)
18. com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs (9 changes)
19. com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs (25 changes)
20. com.unity.ml-agents/CHANGELOG.md (34 changes)
21. com.unity.ml-agents/Runtime/Academy.cs (7 changes)
22. com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs (13 changes)
23. com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs (50 changes)
24. com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs (27 changes)
25. com.unity.ml-agents/Runtime/Agent.cs (50 changes)
26. com.unity.ml-agents/Runtime/Analytics/Events.cs (68 changes)
27. com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs (14 changes)
28. com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs (64 changes)
29. com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (5 changes)
30. com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs (5 changes)
31. com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs (26 changes)
32. com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs (39 changes)
33. com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs (52 changes)
34. com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs (21 changes)
35. com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs (70 changes)
36. com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs (126 changes)
37. com.unity.ml-agents/Runtime/Inference/ModelRunner.cs (46 changes)
38. com.unity.ml-agents/Runtime/Inference/TensorApplier.cs (7 changes)
39. com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs (63 changes)
40. com.unity.ml-agents/Runtime/Inference/TensorNames.cs (1 change)
41. com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs (9 changes)
42. com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs (11 changes)
43. com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs (14 changes)
44. com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs (8 changes)
45. com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs (15 changes)
46. com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs (24 changes)
47. com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs (10 changes)
48. com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs (9 changes)
49. com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs (9 changes)
50. com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs (3 changes)
51. com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs (7 changes)
52. com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs (11 changes)
53. com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs (24 changes)
54. com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs (12 changes)
55. com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs (26 changes)
56. com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs (18 changes)
57. com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs (11 changes)
58. com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs (19 changes)
59. com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs (19 changes)
60. com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs (6 changes)
61. com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs (32 changes)
62. com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs (2 changes)
63. com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs (9 changes)
64. com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs (2 changes)
65. com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs (12 changes)
66. com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs (2 changes)
67. com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs (15 changes)
68. com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs (4 changes)
69. docs/Background-Unity.md (2 changes)
70. docs/Migrating.md (11 changes)
71. docs/Python-API.md (9 changes)
72. docs/Training-ML-Agents.md (10 changes)
73. gym-unity/gym_unity/envs/__init__.py (16 changes)
74. gym-unity/gym_unity/tests/test_gym.py (6 changes)
75. ml-agents-envs/mlagents_envs/base_env.py (52 changes)
76. ml-agents-envs/mlagents_envs/communicator.py (17 changes)
77. ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py (6 changes)
78. ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi (4 changes)
79. ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py (11 changes)
80. ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi (6 changes)
81. ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py (56 changes)
82. ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi (27 changes)
83. ml-agents-envs/mlagents_envs/env_utils.py (8 changes)
84. ml-agents-envs/mlagents_envs/environment.py (55 changes)
85. ml-agents-envs/mlagents_envs/mock_communicator.py (12 changes)
86. ml-agents-envs/mlagents_envs/rpc_communicator.py (47 changes)
87. ml-agents-envs/mlagents_envs/rpc_utils.py (44 changes)
88. ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py (2 changes)
89. ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py (2 changes)
90. ml-agents-envs/mlagents_envs/tests/test_envs.py (14 changes)
91. ml-agents-envs/mlagents_envs/tests/test_rpc_communicator.py (54 changes)
92. ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py (19 changes)
93. ml-agents-envs/mlagents_envs/tests/test_steps.py (6 changes)
94. ml-agents/mlagents/torch_utils/__init__.py (1 change)
95. ml-agents/mlagents/torch_utils/torch.py (37 changes)
96. ml-agents/mlagents/trainers/cli_utils.py (15 changes)
97. ml-agents/mlagents/trainers/demo_loader.py (9 changes)
98. ml-agents/mlagents/trainers/env_manager.py (12 changes)
99. ml-agents/mlagents/trainers/learn.py (15 changes)
100. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (4 changes)

2
.yamato/com.unity.ml-agents-performance.yml


commands:
- python3 -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple --upgrade
- unity-downloader-cli -u {{ editor.version }} -c editor --wait --fast
- curl -s https://artifactory.internal.unity3d.com/core-automation/tools/utr-standalone/utr --output utr
- curl -s https://artifactory.prd.it.unity3d.com/artifactory/unity-tools-local/utr-standalone/utr --output utr
- chmod +x ./utr
- ./utr --suite=editor --platform=StandaloneOSX --editor-location=.Editor --testproject=DevProject --artifacts_path=build/test-results --report-performance-data --performance-project-id=com.unity.ml-agents --zero-tests-are-ok=1
triggers:

1
.yamato/gym-interface-test.yml


- |
sudo apt-get update && sudo apt-get install -y python3-venv
python3 -m venv venv && source venv/bin/activate
python -m pip install wheel --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
python -u -m ml-agents.tests.yamato.setup_venv
python ml-agents/tests/yamato/scripts/run_gym.py --env=artifacts/testPlayer-Basic

16
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs


using System;
using Unity.MLAgents.Actuators;
using UnityEngine;
namespace Unity.MLAgentsExamples
{

/// <summary>
/// Simple actuator that converts the action into a {-1, 0, 1} direction
/// </summary>
public class BasicActuator : IActuator
public class BasicActuator : IActuator, IHeuristicProvider
{
public BasicController basicController;
ActionSpec m_ActionSpec;

}
basicController.MoveDirection(direction);
}
public void Heuristic(in ActionBuffers actionBuffersOut)
{
var direction = Input.GetAxis("Horizontal");
var discreteActions = actionBuffersOut.DiscreteActions;
if (Mathf.Approximately(direction, 0.0f))
{
discreteActions[0] = 0;
return;
}
var sign = Math.Sign(direction);
discreteActions[0] = sign < 0 ? 1 : 2;
}
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
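
The new Heuristic above encodes keyboard input into a single discrete branch: 0 for no movement, 1 for left, 2 for right. The elided OnActionReceived body decodes that value back into a direction before calling basicController.MoveDirection(direction). A minimal sketch of that decoding, assuming the {-1, 0, 1} mapping described in the class summary (not the verbatim file contents):

public void OnActionReceived(ActionBuffers actionBuffers)
{
    var action = actionBuffers.DiscreteActions[0];
    // Invert the heuristic's encoding: 0 -> stay, 1 -> move left, 2 -> move right.
    var direction = action == 0 ? 0 : (action == 1 ? -1 : 1);
    basicController.MoveDirection(direction);
}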

3
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using UnityEngine.Rendering;
using UnityEngine.Serialization;
public class GridAgent : Agent

void WaitTimeInference()
{
if (renderCamera != null)
if (renderCamera != null && SystemInfo.graphicsDeviceType != GraphicsDeviceType.Null)
{
renderCamera.Render();
}

25
Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab


- component: {fileID: 3508723250470608012}
- component: {fileID: 3508723250470608011}
- component: {fileID: 3508723250470608009}
- component: {fileID: 3508723250470608013}
- component: {fileID: 2112317463290853299}
m_Layer: 0
m_Name: Match3 Agent
m_TagString: Untagged

m_BrainParameters:
VectorObservationSize: 0
NumStackedVectorObservations: 1
m_ActionSpec:
m_NumContinuousActions: 0
BranchSizes:
hasUpgradedBrainParametersWithActionSpec: 1
m_Model: {fileID: 11400000, guid: c34da50737a3c4a50918002b20b2b927, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0

Board: {fileID: 0}
MoveTime: 0.25
MaxMoves: 500
HeuristicQuality: 0
--- !u!114 &3508723250470608011
MonoBehaviour:
m_ObjectHideFlags: 0

m_EditorClassIdentifier:
DebugMoveIndex: -1
CubeSpacing: 1.25
Board: {fileID: 0}
TilePrefab: {fileID: 4007900521885639951, guid: faee4e805953b49e688bd00b45c55f2e,
type: 3}
--- !u!114 &3508723250470608009

BasicCellPoints: 1
SpecialCell1Points: 2
SpecialCell2Points: 3
--- !u!114 &3508723250470608013
--- !u!114 &3508723250470608014
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}

m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3}
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
ActuatorName: Match3 Actuator
ForceHeuristic: 1
--- !u!114 &3508723250470608014
SensorName: Match3 Sensor
ObservationType: 0
--- !u!114 &2112317463290853299
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}

m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Script: {fileID: 11500000, guid: b17adcc6c9b241da903aa134f2dac930, type: 3}
SensorName: Match3 Sensor
ObservationType: 0
ActuatorName: Match3 Actuator
ForceHeuristic: 1
HeuristicQuality: 0
--- !u!1 &3508723250774301855
GameObject:
m_ObjectHideFlags: 0

25
Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab


- component: {fileID: 2118285884327540682}
- component: {fileID: 2118285884327540685}
- component: {fileID: 2118285884327540687}
- component: {fileID: 2118285884327540683}
- component: {fileID: 3357012711826686276}
m_Layer: 0
m_Name: Match3 Agent
m_TagString: Untagged

m_BrainParameters:
VectorObservationSize: 0
NumStackedVectorObservations: 1
m_ActionSpec:
m_NumContinuousActions: 0
BranchSizes:
hasUpgradedBrainParametersWithActionSpec: 1
m_Model: {fileID: 11400000, guid: 9e89b8e81974148d3b7213530d00589d, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0

Board: {fileID: 0}
MoveTime: 0.25
MaxMoves: 500
HeuristicQuality: 0
--- !u!114 &2118285884327540685
MonoBehaviour:
m_ObjectHideFlags: 0

m_EditorClassIdentifier:
DebugMoveIndex: -1
CubeSpacing: 1.25
Board: {fileID: 0}
TilePrefab: {fileID: 4007900521885639951, guid: faee4e805953b49e688bd00b45c55f2e,
type: 3}
--- !u!114 &2118285884327540687

BasicCellPoints: 1
SpecialCell1Points: 2
SpecialCell2Points: 3
--- !u!114 &2118285884327540683
--- !u!114 &2118285884327540680
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}

m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3}
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
ActuatorName: Match3 Actuator
ForceHeuristic: 0
--- !u!114 &2118285884327540680
SensorName: Match3 Sensor
ObservationType: 0
--- !u!114 &3357012711826686276
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}

m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Script: {fileID: 11500000, guid: b17adcc6c9b241da903aa134f2dac930, type: 3}
SensorName: Match3 Sensor
ObservationType: 0
ActuatorName: Match3 Actuator
ForceHeuristic: 0
HeuristicQuality: 0

25
Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab


- component: {fileID: 3019509692332007781}
- component: {fileID: 3019509692332007778}
- component: {fileID: 3019509692332007776}
- component: {fileID: 3019509692332007780}
- component: {fileID: 8270768986451624427}
m_Layer: 0
m_Name: Match3 Agent
m_TagString: Untagged

m_BrainParameters:
VectorObservationSize: 0
NumStackedVectorObservations: 1
m_ActionSpec:
m_NumContinuousActions: 0
BranchSizes:
hasUpgradedBrainParametersWithActionSpec: 1
m_Model: {fileID: 11400000, guid: 48d14da88fea74d0693c691c6e3f2e34, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0

Board: {fileID: 0}
MoveTime: 0.25
MaxMoves: 500
HeuristicQuality: 0
--- !u!114 &3019509692332007778
MonoBehaviour:
m_ObjectHideFlags: 0

m_EditorClassIdentifier:
DebugMoveIndex: -1
CubeSpacing: 1.25
Board: {fileID: 0}
TilePrefab: {fileID: 4007900521885639951, guid: faee4e805953b49e688bd00b45c55f2e,
type: 3}
--- !u!114 &3019509692332007776

BasicCellPoints: 1
SpecialCell1Points: 2
SpecialCell2Points: 3
--- !u!114 &3019509692332007780
--- !u!114 &3019509692332007783
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}

m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3}
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
ActuatorName: Match3 Actuator
ForceHeuristic: 0
--- !u!114 &3019509692332007783
SensorName: Match3 Sensor
ObservationType: 2
--- !u!114 &8270768986451624427
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}

m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Script: {fileID: 11500000, guid: b17adcc6c9b241da903aa134f2dac930, type: 3}
SensorName: Match3 Sensor
ObservationType: 2
ActuatorName: Match3 Actuator
ForceHeuristic: 0
HeuristicQuality: 0

10
Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity


m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 2112317463290853299, guid: 2fafdcd0587684641b03b11f04454f1b,
type: 3}
propertyPath: HeuristicQuality
value: 1
objectReference: {fileID: 0}
- target: {fileID: 3508723250470608011, guid: 2fafdcd0587684641b03b11f04454f1b,
type: 3}
propertyPath: cubeSpacing

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 2112317463290853299, guid: 2fafdcd0587684641b03b11f04454f1b,
type: 3}
propertyPath: HeuristicQuality
value: 1
objectReference: {fileID: 0}
- target: {fileID: 3508723250470608011, guid: 2fafdcd0587684641b03b11f04454f1b,
type: 3}
propertyPath: cubeSpacing

166
Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs


WaitForMove = 4,
}
public enum HeuristicQuality
{
/// <summary>
/// The heuristic will pick any valid move at random.
/// </summary>
RandomValidMove,
/// <summary>
/// The heuristic will pick the move that scores the most points.
/// This only looks at the immediate move, and doesn't consider where cells will fall.
/// </summary>
Greedy
}
public class Match3Agent : Agent
{
[HideInInspector]

public int MaxMoves = 500;
public HeuristicQuality HeuristicQuality = HeuristicQuality.RandomValidMove;
private System.Random m_Random;
var seed = Board.RandomSeed == -1 ? gameObject.GetInstanceID() : Board.RandomSeed + 1;
m_Random = new System.Random(seed);
}
public override void OnEpisodeBegin()

return false;
}
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActions = actionsOut.DiscreteActions;
discreteActions[0] = GreedyMove();
}
int GreedyMove()
{
var pointsByType = new[] { Board.BasicCellPoints, Board.SpecialCell1Points, Board.SpecialCell2Points };
var bestMoveIndex = 0;
var bestMovePoints = -1;
var numMovesAtCurrentScore = 0;
foreach (var move in Board.ValidMoves())
{
var movePoints = HeuristicQuality == HeuristicQuality.Greedy ? EvalMovePoints(move, pointsByType) : 1;
if (movePoints < bestMovePoints)
{
// Worse, skip
continue;
}
if (movePoints > bestMovePoints)
{
// Better, keep
bestMovePoints = movePoints;
bestMoveIndex = move.MoveIndex;
numMovesAtCurrentScore = 1;
}
else
{
// Tied for best - use reservoir sampling to make sure we select from equal moves uniformly.
// See https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm
numMovesAtCurrentScore++;
var randVal = m_Random.Next(0, numMovesAtCurrentScore);
if (randVal == 0)
{
// Keep the new one
bestMoveIndex = move.MoveIndex;
}
}
}
return bestMoveIndex;
}
int EvalMovePoints(Move move, int[] pointsByType)
{
// Counts the expected points for making the move.
var moveVal = Board.GetCellType(move.Row, move.Column);
var moveSpecial = Board.GetSpecialType(move.Row, move.Column);
var (otherRow, otherCol) = move.OtherCell();
var oppositeVal = Board.GetCellType(otherRow, otherCol);
var oppositeSpecial = Board.GetSpecialType(otherRow, otherCol);
int movePoints = EvalHalfMove(
otherRow, otherCol, moveVal, moveSpecial, move.Direction, pointsByType
);
int otherPoints = EvalHalfMove(
move.Row, move.Column, oppositeVal, oppositeSpecial, move.OtherDirection(), pointsByType
);
return movePoints + otherPoints;
}
int EvalHalfMove(int newRow, int newCol, int newValue, int newSpecial, Direction incomingDirection, int[] pointsByType)
{
// This is essentially a duplicate of AbstractBoard.CheckHalfMove, but it also counts the points for the move.
int matchedLeft = 0, matchedRight = 0, matchedUp = 0, matchedDown = 0;
int scoreLeft = 0, scoreRight = 0, scoreUp = 0, scoreDown = 0;
if (incomingDirection != Direction.Right)
{
for (var c = newCol - 1; c >= 0; c--)
{
if (Board.GetCellType(newRow, c) == newValue)
{
matchedLeft++;
scoreLeft += pointsByType[Board.GetSpecialType(newRow, c)];
}
else
break;
}
}
if (incomingDirection != Direction.Left)
{
for (var c = newCol + 1; c < Board.Columns; c++)
{
if (Board.GetCellType(newRow, c) == newValue)
{
matchedRight++;
scoreRight += pointsByType[Board.GetSpecialType(newRow, c)];
}
else
break;
}
}
if (incomingDirection != Direction.Down)
{
for (var r = newRow + 1; r < Board.Rows; r++)
{
if (Board.GetCellType(r, newCol) == newValue)
{
matchedUp++;
scoreUp += pointsByType[Board.GetSpecialType(r, newCol)];
}
else
break;
}
}
if (incomingDirection != Direction.Up)
{
for (var r = newRow - 1; r >= 0; r--)
{
if (Board.GetCellType(r, newCol) == newValue)
{
matchedDown++;
scoreDown += pointsByType[Board.GetSpecialType(r, newCol)];
}
else
break;
}
}
if ((matchedUp + matchedDown >= 2) || (matchedLeft + matchedRight >= 2))
{
// It's a match. Start by counting the piece being moved
var totalScore = pointsByType[newSpecial];
if (matchedUp + matchedDown >= 2)
{
totalScore += scoreUp + scoreDown;
}
if (matchedLeft + matchedRight >= 2)
{
totalScore += scoreLeft + scoreRight;
}
return totalScore;
}
return 0;
}
}
}

13
Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs


using System;
using Unity.MLAgents.Extensions.Match3;
using UnityEngine;

public class Match3Board : AbstractBoard
{
public int RandomSeed = -1;
public const int k_EmptyCell = -1;
[Tooltip("Points earned for clearing a basic cell (cube)")]
public int BasicCellPoints = 1;

[Tooltip("Points earned for clearing an extra special cell (plus)")]
public int SpecialCell2Points = 3;
/// <summary>
/// Seed to initialize the <see cref="System.Random"/> object.
/// </summary>
public int RandomSeed;
(int, int)[,] m_Cells;
bool[,] m_Matched;

m_Cells = new (int, int)[Columns, Rows];
m_Matched = new bool[Columns, Rows];
}
void Start()
{
InitRandom();
}

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs


float[] buffer = new float[numFloats];
WriteObservation(buffer);
writer.AddRange(buffer);
writer.AddList(buffer);
return numFloats;
}

2
Project/ProjectSettings/UnityConnectSettings.asset


UnityConnectSettings:
m_ObjectHideFlags: 0
serializedVersion: 1
m_Enabled: 1
m_Enabled: 0
m_TestMode: 0
m_EventOldUrl: https://api.uca.cloud.unity3d.com/v1/events
m_EventUrl: https://cdp.cloud.unity3d.com/v1/events

51
README.md


- The **Documentation** links in the table below include installation and usage
instructions specific to each release. Remember to always use the
documentation that corresponds to the release version you're using.
| **Version** | **Release Date** | **Source** | **Documentation** | **Download** |
|:-------:|:------:|:-------------:|:-------:|:------------:|
| **master (unstable)** | -- | [source](https://github.com/Unity-Technologies/ml-agents/tree/master) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/master/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/master.zip) |
| **Release 12** | **December 22, 2020** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/release_12)** | **[docs](https://github.com/Unity-Technologies/ml-agents/tree/release_12_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/release_12.zip)** |
| **Release 11** | December 21, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_11) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_11_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_11.zip) |
| **Release 10** | November 18, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_10) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_10_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_10.zip) |
| **Release 9** | November 4, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_9) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_9_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_9.zip) |
| **Release 8** | October 14, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_8) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_8_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_8.zip) |
| **Release 7** | September 16, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_7) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_7_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_7.zip) |
| **Release 6** | August 12, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_6) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_6_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_6.zip) |
| **Release 5** | July 31, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_5) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_5_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_5.zip) |
- The `com.unity.ml-agents` package is [verified](https://docs.unity3d.com/2020.1/Documentation/Manual/pack-safe.html)
for Unity 2020.1 and later. Verified package releases are numbered 1.0.x.
## Citation
| **Version** | **Release Date** | **Source** | **Documentation** | **Download** | **Python Package** | **Unity Package** |
|:-------:|:------:|:-------------:|:-------:|:------------:|:------------:|:------------:|
| **master (unstable)** | -- | [source](https://github.com/Unity-Technologies/ml-agents/tree/master) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/master/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/master.zip) | -- | -- |
| **Release 12** | **December 22, 2020** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/release_12)** | **[docs](https://github.com/Unity-Technologies/ml-agents/tree/release_12_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/release_12.zip)** | **[0.23.0](https://pypi.org/project/mlagents/0.23.0/)** | **[1.7.2](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.7/manual/index.html)** |
| **Release 11** | December 21, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_11) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_11_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_11.zip) | [0.23.0](https://pypi.org/project/mlagents/0.23.0/) | [1.7.0](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.7/manual/index.html) |
| **Release 10** | November 18, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_10) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_10_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_10.zip) | [0.22.0](https://pypi.org/project/mlagents/0.22.0/) | [1.6.0](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.6/manual/index.html) |
| **Verified Package 1.0.6** | **November 16, 2020** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/com.unity.ml-agents_1.0.6)** | **[docs](https://github.com/Unity-Technologies/ml-agents/blob/release_2_verified_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/com.unity.ml-agents_1.0.6.zip)** | **[0.16.1](https://pypi.org/project/mlagents/0.16.1/)** | **[1.0.6](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.0/manual/index.html)** |
| **Release 9** | November 4, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_9) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_9_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_9.zip) | [0.21.1](https://pypi.org/project/mlagents/0.21.1/) | [1.5.0](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.5/manual/index.html) |
| **Release 8** | October 14, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_8) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_8_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_8.zip) | [0.21.0](https://pypi.org/project/mlagents/0.21.0/) | [1.5.0](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.5/manual/index.html) |
| **Verified Package 1.0.5** | September 23, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/com.unity.ml-agents_1.0.5) | [docs](https://github.com/Unity-Technologies/ml-agents/blob/release_2_verified_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/com.unity.ml-agents_1.0.5.zip) | [0.16.1](https://pypi.org/project/mlagents/0.16.1/) | [1.0.5](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.0/manual/index.html) |
| **Release 7** | September 16, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_7) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_7_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_7.zip) | [0.20.0](https://pypi.org/project/mlagents/0.20.0/) | [1.4.0](https://docs.unity3d.com/Packages/com.unity.ml-agents@1.4/manual/index.html) |
If you are a researcher interested in a discussion of Unity as an AI platform,
see a pre-print of our
[reference paper on Unity and the ML-Agents Toolkit](https://arxiv.org/abs/1809.02627).

## Additional Resources
We have published a series of blog posts that are relevant for ML-Agents:
We have a Unity Learn course,
[ML-Agents: Hummingbirds](https://learn.unity.com/course/ml-agents-hummingbirds),
that provides a gentle introduction to Unity and the ML-Agents Toolkit.
We've also partnered with
[CodeMonkeyUnity](https://www.youtube.com/c/CodeMonkeyUnity) to create a
[series of tutorial videos](https://www.youtube.com/playlist?list=PLzDRvYVwl53vehwiN_odYJkPBzcqFw110)
on how to implement and use the ML-Agents Toolkit.
We have also published a series of blog posts that are relevant for ML-Agents:
- (December 28, 2020)
[Happy holidays from the Unity ML-Agents team!](https://blogs.unity3d.com/2020/12/28/happy-holidays-from-the-unity-ml-agents-team/)
- (November 20, 2020)
[How Eidos-Montréal created Grid Sensors to improve observations for training agents](https://blogs.unity3d.com/2020/11/20/how-eidos-montreal-created-grid-sensors-to-improve-observations-for-training-agents/)
- (November 11, 2020)
[2020 AI@Unity interns shoutout](https://blogs.unity3d.com/2020/11/11/2020-aiunity-interns-shoutout/)
- (May 12, 2020)
[Announcing ML-Agents Unity Package v1.0!](https://blogs.unity3d.com/2020/05/12/announcing-ml-agents-unity-package-v1-0/)
- (February 28, 2020)

([multi-armed bandit](https://blogs.unity3d.com/2017/06/26/unity-ai-themed-blog-entries/)
and
[Q-learning](https://blogs.unity3d.com/2017/08/22/unity-ai-reinforcement-learning-with-q-learning/))
In addition to our own documentation, here are some additional, relevant
articles:
- [A Game Developer Learns Machine Learning](https://mikecann.co.uk/posts/a-game-developer-learns-machine-learning-intent)
- [Explore Unity Technologies ML-Agents Exclusively on Intel Architecture](https://software.intel.com/en-us/articles/explore-unity-technologies-ml-agents-exclusively-on-intel-architecture)
- [ML-Agents Penguins tutorial](https://learn.unity.com/project/ml-agents-penguins)
## Community and Feedback

74
com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs


/// Actuator for a Match3 game. It translates valid moves (defined by AbstractBoard.IsMoveValid())
/// in action masks, and applies the action to the board via AbstractBoard.MakeMove().
/// </summary>
public class Match3Actuator : IActuator
public class Match3Actuator : IActuator, IHeuristicProvider
private AbstractBoard m_Board;
protected AbstractBoard m_Board;
protected System.Random m_Random;
private System.Random m_Random;
private Agent m_Agent;
private int m_Rows;

/// <param name="board"></param>
/// <param name="forceHeuristic">Whether the inference action should be ignored and the Agent's Heuristic
/// should be called. This should only be used for generating comparison stats of the Heuristic.</param>
/// <param name="seed">The seed used to initialize <see cref="System.Random"/>.</param>
public Match3Actuator(AbstractBoard board, bool forceHeuristic, Agent agent, string name)
public Match3Actuator(AbstractBoard board,
bool forceHeuristic,
int seed,
Agent agent,
string name)
{
m_Board = board;
m_Rows = board.Rows;

var numMoves = Move.NumPotentialMoves(m_Board.Rows, m_Board.Columns);
m_ActionSpec = ActionSpec.MakeDiscrete(numMoves);
m_Random = new System.Random(seed);
}
/// <inheritdoc/>

{
if (m_ForceHeuristic)
{
m_Agent.Heuristic(actions);
Heuristic(actions);
}
var moveIndex = actions.DiscreteActions[0];

yield return move.MoveIndex;
}
}
public void Heuristic(in ActionBuffers actionsOut)
{
var discreteActions = actionsOut.DiscreteActions;
discreteActions[0] = GreedyMove();
}
protected int GreedyMove()
{
var bestMoveIndex = 0;
var bestMovePoints = -1;
var numMovesAtCurrentScore = 0;
foreach (var move in m_Board.ValidMoves())
{
var movePoints = EvalMovePoints(move);
if (movePoints < bestMovePoints)
{
// Worse, skip
continue;
}
if (movePoints > bestMovePoints)
{
// Better, keep
bestMovePoints = movePoints;
bestMoveIndex = move.MoveIndex;
numMovesAtCurrentScore = 1;
}
else
{
// Tied for best - use reservoir sampling to make sure we select from equal moves uniformly.
// See https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm
numMovesAtCurrentScore++;
var randVal = m_Random.Next(0, numMovesAtCurrentScore);
if (randVal == 0)
{
// Keep the new one
bestMoveIndex = move.MoveIndex;
}
}
}
return bestMoveIndex;
}
/// <summary>
/// Method to be overridden when evaluating how many points a specific move will generate.
/// </summary>
/// <param name="move">The move to evaluate.</param>
/// <returns>The number of points the move generates.</returns>
protected virtual int EvalMovePoints(Move move)
{
return 1;
}
}
}
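
The base EvalMovePoints above returns 1 for every move, which makes GreedyMove reduce to picking a valid move uniformly at random; subclasses override it to rank moves by expected score. A minimal sketch of such an override, with the per-type points table as a hypothetical constructor argument (m_Board and the base constructor signature come from the hunk above):

public class ScoredMatch3Actuator : Match3Actuator
{
    readonly int[] m_PointsByType;  // hypothetical score table indexed by special type

    public ScoredMatch3Actuator(AbstractBoard board, bool forceHeuristic, int seed,
        Agent agent, string name, int[] pointsByType)
        : base(board, forceHeuristic, seed, agent, name)
    {
        m_PointsByType = pointsByType;
    }

    protected override int EvalMovePoints(Move move)
    {
        // Score only the moved cell by its special type; a full implementation would
        // also count the cells it matches with, as Match3Agent.EvalMovePoints does.
        return m_PointsByType[m_Board.GetSpecialType(move.Row, move.Column)];
    }
}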

10
com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs


namespace Unity.MLAgents.Extensions.Match3
{
/// <summary>
/// Actuator component for a Match 3 game. Generates a Match3Actuator at runtime.
/// Actuator component for a Match3 game. Generates a Match3Actuator at runtime.
/// </summary>
public class Match3ActuatorComponent : ActuatorComponent
{

public string ActuatorName = "Match3 Actuator";
/// <summary>
/// A random seed used to generate a board, if needed.
/// </summary>
public int RandomSeed = -1;
/// <summary>
/// Force using the Agent's Heuristic() method to decide the action. This should only be used in testing.
/// </summary>
[FormerlySerializedAs("ForceRandom")]

{
var board = GetComponent<AbstractBoard>();
var agent = GetComponentInParent<Agent>();
return new Match3Actuator(board, ForceHeuristic, agent, ActuatorName);
var seed = RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed + 1;
return new Match3Actuator(board, ForceHeuristic, seed, agent, ActuatorName);
}
/// <inheritdoc/>

8
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs


/// or uncompressed visual observations. Uses AbstractBoard.GetCellType()
/// and AbstractBoard.GetSpecialType() to determine the observation values.
/// </summary>
public class Match3Sensor : ISparseChannelSensor
public class Match3Sensor : ISparseChannelSensor, IBuiltInSensor
{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;

public int[] GetCompressedChannelMapping()
{
return m_SparseChannelMapping;
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.Match3Sensor;
}
static void DestroyTexture(Texture2D texture)

9
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


/// <summary>
/// Grid-based sensor.
/// </summary>
public class GridSensor : SensorComponent, ISensor
public class GridSensor : SensorComponent, ISensor, IBuiltInSensor
{
/// <summary>
/// Name of this grid sensor.

{
return CompressionType;
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.GridSensor;
}
/// <summary>
/// GetCompressedObservation - Calls Perceive then puts the data stored on the perception buffer

9
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


/// <summary>
/// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
/// </summary>
public class PhysicsBodySensor : ISensor
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
int[] m_Shape;
string m_SensorName;

{
return m_SensorName;
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.PhysicsBodySensor;
}
}
}

25
com.unity.ml-agents.extensions/Runtime/Teams/BaseTeamManager.cs


using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
private readonly string m_Id = System.Guid.NewGuid().ToString();
readonly int m_Id = TeamManagerIdCounter.GetTeamManagerId();
public virtual void RegisterAgent(Agent agent)
{
}
public virtual void RegisterAgent(Agent agent) { }
public virtual void OnAgentDone(Agent agent, Agent.DoneReason doneReason, List<ISensor> sensors)
{
// Possible implementation - save reference to Agent's IPolicy so that we can repeatedly
// call IPolicy.RequestDecision on behalf of the Agent after it's dead
// If so, we'll need dummy sensor impls with the same shape as the originals.
agent.SendDoneToTrainer();
}
public virtual void AddTeamReward(float reward)
{
}
public string GetId()
public int GetId()
{
return m_Id;
}
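
GetId now returns an int handed out by TeamManagerIdCounter rather than a GUID string. The counter type itself is not part of this hunk; a minimal thread-safe sketch of what it could look like, assuming an Interlocked-based implementation:

using System.Threading;

internal static class TeamManagerIdCounter  // sketch only; the real type is not shown in this diff
{
    static int s_Counter;

    public static int GetTeamManagerId()
    {
        // Interlocked.Increment keeps ids unique even if team managers
        // are constructed from multiple threads.
        return Interlocked.Increment(ref s_Counter);
    }
}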

34
com.unity.ml-agents/CHANGELOG.md


### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- The `ActionSpec` constructor is now public. Previously, it was not possible to create an
ActionSpec with both continuous and discrete actions from code. (#4896)
will result in the values being summed (instead of averaged) when written to
TensorBoard. Thanks to @brccabral for the contribution! (#4816)
- The upper limit for the time scale (by setting the `--time-scale` parameter in mlagents-learn) was
removed when training with a player. The Editor still requires it to be clamped to 100. (#4867)
- Added the IHeuristicProvider interface to allow IActuators, as well as Agents, to implement the Heuristic function to generate actions.
Updated the Basic example and the Match3 example to use Actuators.
Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849)
- Added `VectorSensor.AddObservation(IList<float>)`. `VectorSensor.AddObservation(IEnumerable<float>)`
is deprecated. The `IList` version is recommended, as it does not generate any
additional memory allocations. (#4887)
- Added `ObservationWriter.AddList()` and deprecated `ObservationWriter.AddRange()`.
`AddList()` is recommended, as it does not generate any additional memory allocations. (#4887)
- Added a `--torch-device` commandline option to `mlagents-learn`, which sets the default
[`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device) used for training. (#4888)
- The `--cpu` commandline option had no effect and was removed. Use `--torch-device=cpu` to force CPU training. (#4888)
- CameraSensor now logs an error if the GraphicsDevice is null. (#4880)
- Removed unnecessary memory allocations in `ActuatorManager.UpdateActionArray()` (#4877)
- Removed unnecessary memory allocations in `SensorShapeValidator.ValidateSensors()` (#4879)
- Removed unnecessary memory allocations in `SideChannelManager.GetSideChannelMessage()` (#4886)
- Removed several memory allocations that happened during inference. On a test scene, this
reduced the amount of memory allocated by approximately 25%. (#4887)
- Fixed a bug that could cause a crash if a new behavior appeared during multi-environment training. (#4872)
- Fixed the computation of entropy for continuous actions. (#4869)
- Fixed a bug that would cause `UnityEnvironment` to wait the full timeout
period and report a misleading error message if the executable crashed
without closing the connection. It now periodically checks the process status
while waiting for a connection, and raises a better error message if it crashes. (#4880)
- Passing a `-logfile` option in the `--env-args` option to `mlagents-learn` is
no longer overwritten. (#4880)
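
To illustrate the allocation-related API changes called out above, here is a short usage sketch; the surrounding agent class and the reusable m_ObsBuffer field are hypothetical, while AddObservation(IList&lt;float&gt;) is the overload named in the changelog:

public override void CollectObservations(VectorSensor sensor)
{
    // m_ObsBuffer is a List<float> reused across steps, so the new
    // IList<float> overload adds observations without per-call allocations.
    sensor.AddObservation(m_ObsBuffer);
}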
## [1.7.2-preview] - 2020-12-22

7
com.unity.ml-agents/Runtime/Academy.cs


/// <term>1.3.0</term>
/// <description>Support both continuous and discrete actions.</description>
/// </item>
/// <item>
/// <term>1.4.0</term>
/// <description>Support training analytics sent from python trainer to the editor.</description>
/// </item>
const string k_ApiVersion = "1.3.0";
const string k_ApiVersion = "1.4.0";
/// <summary>
/// Unity package version of com.unity.ml-agents.

EnableAutomaticStepping();
SideChannelManager.RegisterSideChannel(new EngineConfigurationChannel());
SideChannelManager.RegisterSideChannel(new TrainingAnalyticsSideChannel());
m_EnvironmentParameters = new EnvironmentParameters();
m_StatsRecorder = new StatsRecorder();

13
com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs


/// <summary>
/// Creates a Continuous <see cref="ActionSpec"/> with the number of actions available.
/// </summary>
/// <param name="numActions">The number of actions available.</param>
/// <param name="numActions">The number of continuous actions available.</param>
/// <returns>An Continuous ActionSpec initialized with the number of actions available.</returns>
public static ActionSpec MakeContinuous(int numActions)
{

return actuatorSpace;
}
internal ActionSpec(int numContinuousActions, int[] branchSizes = null)
/// <summary>
/// Create an ActionSpec initialized with the specified action sizes.
/// </summary>
/// <param name="numContinuousActions">The number of continuous actions available.</param>
/// <param name="discreteBranchSizes">The array of branch sizes for the discrete actions. Each index
/// contains the number of actions available for that branch.</param>
/// <returns>An ActionSpec initialized with the specified action sizes.</returns>
public ActionSpec(int numContinuousActions = 0, int[] discreteBranchSizes = null)
BranchSizes = branchSizes;
BranchSizes = discreteBranchSizes;
}
/// <summary>
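
With the constructor public, a hybrid action spec can now be built directly from code; a small sketch using the signature shown above:

// Two continuous actions plus two discrete branches of sizes 3 and 2.
var spec = new ActionSpec(numContinuousActions: 2, discreteBranchSizes: new[] { 3, 2 });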

50
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs


}
else
{
Debug.Assert(sourceActionBuffer.Length == destination.Length,
$"sourceActionBuffer:{sourceActionBuffer.Length} is a different" +
$" size than destination: {destination.Length}.");
Debug.AssertFormat(sourceActionBuffer.Length == destination.Length,
"sourceActionBuffer: {0} is a different size than destination: {1}.",
sourceActionBuffer.Length,
destination.Length);
Array.Copy(sourceActionBuffer.Array,
sourceActionBuffer.Offset,

actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
offset += actuator.ActionSpec.NumDiscreteActions;
}
}
}
/// <summary>
/// Iterates through all of the IActuators in this list and calls their
/// <see cref="IHeuristicProvider.Heuristic"/> method on them, if implemented, with the appropriate
/// <see cref="ActionSegment{T}"/>s depending on their <see cref="ActionSpec"/>.
/// </summary>
public void ApplyHeuristic(in ActionBuffers actionBuffersOut)
{
var continuousStart = 0;
var discreteStart = 0;
for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;
if (numContinuousActions == 0 && numDiscreteActions == 0)
{
continue;
}
var continuousActions = ActionSegment<float>.Empty;
if (numContinuousActions > 0)
{
continuousActions = new ActionSegment<float>(actionBuffersOut.ContinuousActions.Array,
continuousStart,
numContinuousActions);
}
var discreteActions = ActionSegment<int>.Empty;
if (numDiscreteActions > 0)
{
discreteActions = new ActionSegment<int>(actionBuffersOut.DiscreteActions.Array,
discreteStart,
numDiscreteActions);
}
var heuristic = actuator as IHeuristicProvider;
heuristic?.Heuristic(new ActionBuffers(continuousActions, discreteActions));
continuousStart += numContinuousActions;
discreteStart += numDiscreteActions;
}
}
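
ApplyHeuristic hands each actuator a view over only its own slice of the combined buffers, so an actuator's Heuristic always writes at index 0 of its segment. A hedged sketch of an actuator that participates (the class name and constant value are illustrative):

public class ConstantActuator : IActuator, IHeuristicProvider
{
    public ActionSpec ActionSpec { get; } = ActionSpec.MakeDiscrete(2);
    public string Name => "ConstantActuator";

    public void Heuristic(in ActionBuffers actionBuffersOut)
    {
        // Index 0 is relative to this actuator's segment; ApplyHeuristic
        // computed the offsets into the combined buffers above.
        actionBuffersOut.DiscreteActions[0] = 1;
    }

    public void OnActionReceived(ActionBuffers actionBuffers) { }
    public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { }
    public void ResetData() { }
}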

27
com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs


namespace Unity.MLAgents.Actuators
{
/// <summary>
/// IActuator implementation that forwards to an <see cref="IActionReceiver"/>.
/// IActuator implementation that forwards calls to an <see cref="IActionReceiver"/> and an <see cref="IHeuristicProvider"/>.
internal class VectorActuator : IActuator
internal class VectorActuator : IActuator, IHeuristicProvider
IHeuristicProvider m_HeuristicProvider;
ActionBuffers m_ActionBuffers;
internal ActionBuffers ActionBuffers

/// <summary>
/// Create a VectorActuator that forwards to the provided IActionReceiver.
/// </summary>
/// <param name="actionReceiver">The <see cref="IActionReceiver"/> used for OnActionReceived and WriteDiscreteActionMask.
/// If this parameter also implements <see cref="IHeuristicProvider"/> it will be cast and used to forward calls to
/// <see cref="IHeuristicProvider.Heuristic"/>.</param>
/// <param name="actionSpec"></param>
/// <param name="name"></param>
public VectorActuator(IActionReceiver actionReceiver,
ActionSpec actionSpec,
string name = "VectorActuator")
: this(actionReceiver, actionReceiver as IHeuristicProvider, actionSpec, name) { }
/// <summary>
/// Create a VectorActuator that forwards to the provided IActionReceiver.
/// </summary>
/// <param name="heuristicProvider">The <see cref="IHeuristicProvider"/> used to fill the <see cref="ActionBuffers"/>
/// for Heuristic Policies.</param>
IHeuristicProvider heuristicProvider,
m_HeuristicProvider = heuristicProvider;
ActionSpec = actionSpec;
string suffix;
if (actionSpec.NumContinuousActions == 0)

{
ActionBuffers = actionBuffers;
m_ActionReceiver.OnActionReceived(ActionBuffers);
}
public void Heuristic(in ActionBuffers actionBuffersOut)
{
m_HeuristicProvider?.Heuristic(actionBuffersOut);
}
/// <inheritdoc />

50
com.unity.ml-agents/Runtime/Agent.cs


/// <summary>
/// Team Manager identifier.
/// </summary>
public string teamManagerId;
public int teamManagerId;
public void ClearActions()
{

"docs/Learning-Environment-Design-Agents.md")]
[Serializable]
[RequireComponent(typeof(BehaviorParameters))]
public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver
public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver, IHeuristicProvider
{
IPolicy m_Brain;
BehaviorParameters m_PolicyFactory;

private ITeamManager m_TeamManager;
/// <summary>
/// This is used to avoid allocation of a float array during legacy calls to Heuristic.
/// </summary>
float[] m_LegacyHeuristicCache;
ITeamManager m_TeamManager;
/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
/// </summary>

InitializeActuators();
}
m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), Heuristic);
m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), m_ActuatorManager);
ResetData();
Initialize();

new int[m_ActuatorManager.NumDiscreteActions]
);
if (m_TeamManager != null)
{
m_Info.teamManagerId = m_TeamManager.GetId();
}
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

m_Info.reward = m_Reward;
m_Info.done = true;
m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
if (collectObservationsSensor != null)
{
// Make sure the latest observations are being passed to training.

return;
}
m_Brain?.Dispose();
m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), Heuristic);
m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), m_ActuatorManager);
}
/// <summary>

public virtual void Initialize() { }
/// <summary>
/// Implement `Heuristic()` to choose an action for this agent using a custom heuristic.
/// Implement <see cref="Heuristic"/> to choose an action for this agent using a custom heuristic.
/// control of an agent using keyboard, mouse, or game controller input.
/// control of an agent using keyboard, mouse, game controller input, or a script.
///
/// Your heuristic implementation can use any decision making logic you specify. Assign decision
/// values to the <see cref="ActionBuffers.ContinuousActions"/> and <see cref="ActionBuffers.DiscreteActions"/>

switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType)
{
case SpaceType.Continuous:
Heuristic(actionsOut.ContinuousActions.Array);
Heuristic(m_LegacyHeuristicCache);
Array.Copy(m_LegacyHeuristicCache, actionsOut.ContinuousActions.Array, m_LegacyActionCache.Length);
var convertedOut = Array.ConvertAll(actionsOut.DiscreteActions.Array, x => (float)x);
Heuristic(convertedOut);
Heuristic(m_LegacyHeuristicCache);
discreteActionSegment[i] = (int)convertedOut[i];
discreteActionSegment[i] = (int)m_LegacyHeuristicCache[i];
}
/// <summary>

// Support legacy OnActionReceived
// TODO don't set this up if the sizes are 0?
var param = m_PolicyFactory.BrainParameters;
m_VectorActuator = new VectorActuator(this, param.ActionSpec);
m_VectorActuator = new VectorActuator(this, this, param.ActionSpec);
m_LegacyHeuristicCache = new float[m_VectorActuator.TotalNumberOfActions()];
m_ActuatorManager.Add(m_VectorActuator);

m_Info.done = false;
m_Info.maxStepReached = false;
m_Info.episodeId = m_EpisodeId;
m_Info.teamManagerId = m_TeamManager == null ? -1 : m_TeamManager.GetId();
using (TimerStack.Instance.Scoped("RequestDecision"))
{

/// three values in ActionBuffers.ContinuousActions array to use as the force components.
/// During training, the agent's policy learns to set those particular elements of
/// the array to maximize the training rewards the agent receives. (Of course,
/// if you implement a <seealso cref="Heuristic(in ActionBuffers)"/> function, it must use the same
/// if you implement a <seealso cref="Agent.Heuristic(in ActionBuffers)"/> function, it must use the same
/// elements of the action array for the same purpose since there is no learning
/// involved.)
///

if (!actions.ContinuousActions.IsEmpty())
{
m_LegacyActionCache = actions.ContinuousActions.Array;
Array.Copy(actions.ContinuousActions.Array,
m_LegacyActionCache,
actionSpec.NumContinuousActions);
m_LegacyActionCache = Array.ConvertAll(actions.DiscreteActions.Array, x => (float)x);
for (var i = 0; i < m_LegacyActionCache.Length; i++)
{
m_LegacyActionCache[i] = (float)actions.DiscreteActions[i];
}
}
// Disable deprecation warnings so we can call the legacy overload.
#pragma warning disable CS0618

public void SetTeamManager(ITeamManager teamManager)
{
m_TeamManager = teamManager;
m_Info.teamManagerId = teamManager?.GetId();
teamManager?.RegisterAgent(this);
}
}

68
com.unity.ml-agents/Runtime/Analytics/Events.cs


{
public string SensorName;
public string CompressionType;
public int BuiltInSensorType;
public EventObservationDimensionInfo[] DimensionInfos;
public static EventObservationSpec FromSensor(ISensor sensor)

// TODO copy flags when we have them
}
var builtInSensorType =
(sensor as IBuiltInSensor)?.GetBuiltInSensorType() ?? Sensors.BuiltInSensorType.Unknown;
BuiltInSensorType = (int)builtInSensorType,
}
internal struct RemotePolicyInitializedEvent
{
public string TrainingSessionGuid;
/// <summary>
/// Hash of the BehaviorName.
/// </summary>
public string BehaviorName;
public List<EventObservationSpec> ObservationSpecs;
public EventActionSpec ActionSpec;
/// <summary>
/// This will be the same as TrainingEnvironmentInitializedEvent if available, but
/// TrainingEnvironmentInitializedEvent may not always be available with older trainers.
/// </summary>
public string MLAgentsEnvsVersion;
public string TrainerCommunicationVersion;
}
internal struct TrainingEnvironmentInitializedEvent
{
public string TrainingSessionGuid;
public string TrainerPythonVersion;
public string MLAgentsVersion;
public string MLAgentsEnvsVersion;
public string TorchVersion;
public string TorchDeviceType;
public int NumEnvironments;
public int NumEnvironmentParameters;
}
[Flags]
internal enum RewardSignals
{
Extrinsic = 1 << 0,
Gail = 1 << 1,
Curiosity = 1 << 2,
Rnd = 1 << 3,
}
[Flags]
internal enum TrainingFeatures
{
BehavioralCloning = 1 << 0,
Recurrent = 1 << 1,
Threaded = 1 << 2,
SelfPlay = 1 << 3,
Curriculum = 1 << 4,
}
internal struct TrainingBehaviorInitializedEvent
{
public string TrainingSessionGuid;
public string BehaviorName;
public string TrainerType;
public RewardSignals RewardSignalFlags;
public TrainingFeatures TrainingFeatureFlags;
public string VisualEncoder;
public int NumNetworkLayers;
public int NumNetworkHiddenUnits;
}
}

14
com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs


{
const string k_VendorKey = "unity.ml-agents";
const string k_EventName = "ml_agents_inferencemodelset";
const int k_EventVersion = 1;
/// <summary>
/// Whether or not we've registered this particular event yet

/// </summary>
const int k_MaxNumberOfElements = 1000;
/// <summary>
/// Models that we've already sent events for.
/// </summary>

}
#if UNITY_EDITOR
AnalyticsResult result = EditorAnalytics.RegisterEventWithLimit(k_EventName, k_MaxEventsPerHour, k_MaxNumberOfElements, k_VendorKey);
AnalyticsResult result = EditorAnalytics.RegisterEventWithLimit(k_EventName, k_MaxEventsPerHour, k_MaxNumberOfElements, k_VendorKey, k_EventVersion);
#else
AnalyticsResult result = AnalyticsResult.UnsupportedPlatform;
#endif

var data = GetEventForModel(nnModel, behaviorName, inferenceDevice, sensors, actionSpec);
// Note - to debug, use JsonUtility.ToJson on the event.
// Debug.Log(JsonUtility.ToJson(data, true));
//Debug.Log(JsonUtility.ToJson(data, true));
EditorAnalytics.SendEventWithLimit(k_EventName, data);
if (AnalyticsUtils.s_SendEditorAnalytics)
{
EditorAnalytics.SendEventWithLimit(k_EventName, data, k_EventVersion);
}
#else
return;
#endif

var inferenceEvent = new InferenceEvent();
// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
var behaviorNameHash = Hash128.Compute(behaviorName);
inferenceEvent.BehaviorName = behaviorNameHash.ToString();
inferenceEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
inferenceEvent.BarracudaModelSource = barracudaModel.IrSource;
inferenceEvent.BarracudaModelVersion = barracudaModel.IrVersion;
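
The two replaced lines above show the behavior-name hashing moving into a shared helper. Based on them, AnalyticsUtils.Hash presumably wraps the same Hash128 computation; a sketch (the real helper is not part of this hunk):

internal static string Hash(string s)
{
    // Hash the value so no PII or "secret" data is sent with analytics events.
    return Hash128.Compute(s).ToString();
}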

64
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


using UnityEngine;
using System.Runtime.CompilerServices;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Analytics;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;

MaxStepReached = ai.maxStepReached,
Done = ai.done,
Id = ai.episodeId,
TeamManagerId = ai.teamManagerId,
};
if (ai.discreteActionMasks != null)

}
}
observationProto.Shape.AddRange(shape);
// Add the observation type, if any, to the observationProto
var typeSensor = sensor as ITypedSensor;
if (typeSensor != null)
{
observationProto.ObservationType = (ObservationTypeProto)typeSensor.GetObservationType();
}
else
{
observationProto.ObservationType = ObservationTypeProto.Default;
}
return observationProto;
}

ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
CompressedChannelMapping = proto.CompressedChannelMapping,
HybridActions = proto.HybridActions,
TrainingAnalytics = proto.TrainingAnalytics,
};
}

ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
CompressedChannelMapping = rlCaps.CompressedChannelMapping,
HybridActions = rlCaps.HybridActions,
TrainingAnalytics = rlCaps.TrainingAnalytics,
};
}

}
return true;
}
#region Analytics
internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitializedEvent(
this TrainingEnvironmentInitialized inputProto)
{
return new TrainingEnvironmentInitializedEvent
{
TrainerPythonVersion = inputProto.PythonVersion,
MLAgentsVersion = inputProto.MlagentsVersion,
MLAgentsEnvsVersion = inputProto.MlagentsEnvsVersion,
TorchVersion = inputProto.TorchVersion,
TorchDeviceType = inputProto.TorchDeviceType,
NumEnvironments = inputProto.NumEnvs,
NumEnvironmentParameters = inputProto.NumEnvironmentParameters,
};
}
internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEvent(
this TrainingBehaviorInitialized inputProto)
{
RewardSignals rewardSignals = 0;
rewardSignals |= inputProto.ExtrinsicRewardEnabled ? RewardSignals.Extrinsic : 0;
rewardSignals |= inputProto.GailRewardEnabled ? RewardSignals.Gail : 0;
rewardSignals |= inputProto.CuriosityRewardEnabled ? RewardSignals.Curiosity : 0;
rewardSignals |= inputProto.RndRewardEnabled ? RewardSignals.Rnd : 0;
TrainingFeatures trainingFeatures = 0;
trainingFeatures |= inputProto.BehavioralCloningEnabled ? TrainingFeatures.BehavioralCloning : 0;
trainingFeatures |= inputProto.RecurrentEnabled ? TrainingFeatures.Recurrent : 0;
trainingFeatures |= inputProto.TrainerThreaded ? TrainingFeatures.Threaded : 0;
trainingFeatures |= inputProto.SelfPlayEnabled ? TrainingFeatures.SelfPlay : 0;
trainingFeatures |= inputProto.CurriculumEnabled ? TrainingFeatures.Curriculum : 0;
return new TrainingBehaviorInitializedEvent
{
BehaviorName = inputProto.BehaviorName,
TrainerType = inputProto.TrainerType,
RewardSignalFlags = rewardSignals,
TrainingFeatureFlags = trainingFeatures,
VisualEncoder = inputProto.VisualEncoder,
NumNetworkLayers = inputProto.NumNetworkLayers,
NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits,
};
}
#endregion
}
}
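The conversion above folds independent booleans into `[Flags]` enums. A compilable sketch of the pattern (the enum members here are illustrative; the real flag definitions live elsewhere in the Analytics code):

```csharp
using System;

[Flags]
enum RewardSignals
{
    None = 0,
    Extrinsic = 1 << 0,
    Gail = 1 << 1,
    Curiosity = 1 << 2,
    Rnd = 1 << 3,
}

static class FlagsSketch
{
    // Each enabled reward signal sets one bit; the result travels as a single value.
    public static RewardSignals Combine(bool extrinsic, bool gail, bool curiosity, bool rnd)
    {
        RewardSignals flags = 0;
        flags |= extrinsic ? RewardSignals.Extrinsic : 0;
        flags |= gail ? RewardSignals.Gail : 0;
        flags |= curiosity ? RewardSignals.Curiosity : 0;
        flags |= rnd ? RewardSignals.Rnd : 0;
        return flags;
    }
}
```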

5
com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs


using System.Linq;
using UnityEngine;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Analytics;
using Unity.MLAgents.CommunicatorObjects;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.SideChannels;

},
out input);
var pythonCommunicationVersion = initializationInput.RlInitializationInput.CommunicationVersion;
TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion);
var communicationIsCompatible = CheckCommunicationVersionsAreCompatible(unityCommunicationVersion,
pythonCommunicationVersion,

5
com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs


public bool ConcatenatedPngObservations;
public bool CompressedChannelMapping;
public bool HybridActions;
public bool TrainingAnalytics;
/// <summary>
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This

bool baseRlCapabilities = true,
bool concatenatedPngObservations = true,
bool compressedChannelMapping = true,
bool hybridActions = true)
bool hybridActions = true,
bool trainingAnalytics = true)
TrainingAnalytics = trainingAnalytics;
}
/// <summary>

26
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs


"ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
"bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG8SFwoPdGVhbV9tYW5hZ2Vy",
"X2lkGA4gASgJSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
"X2lkGA4gASgFSgQIARACSgQIAhADSgQIAxAESgQIBBAFSgQIBRAGSgQIBhAH",
"SgQIDBANQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3Rz",
"YgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,

/// <summary>Field number for the "team_manager_id" field.</summary>
public const int TeamManagerIdFieldNumber = 14;
private string teamManagerId_ = "";
private int teamManagerId_;
public string TeamManagerId {
public int TeamManagerId {
teamManagerId_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
teamManagerId_ = value;
}
}

if (Id != 0) hash ^= Id.GetHashCode();
hash ^= actionMask_.GetHashCode();
hash ^= observations_.GetHashCode();
if (TeamManagerId.Length != 0) hash ^= TeamManagerId.GetHashCode();
if (TeamManagerId != 0) hash ^= TeamManagerId.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

}
actionMask_.WriteTo(output, _repeated_actionMask_codec);
observations_.WriteTo(output, _repeated_observations_codec);
if (TeamManagerId.Length != 0) {
output.WriteRawTag(114);
output.WriteString(TeamManagerId);
if (TeamManagerId != 0) {
output.WriteRawTag(112);
output.WriteInt32(TeamManagerId);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);

}
size += actionMask_.CalculateSize(_repeated_actionMask_codec);
size += observations_.CalculateSize(_repeated_observations_codec);
if (TeamManagerId.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(TeamManagerId);
if (TeamManagerId != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(TeamManagerId);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();

}
actionMask_.Add(other.actionMask_);
observations_.Add(other.observations_);
if (other.TeamManagerId.Length != 0) {
if (other.TeamManagerId != 0) {
TeamManagerId = other.TeamManagerId;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);

observations_.AddEntriesFrom(input, _repeated_observations_codec);
break;
}
case 114: {
TeamManagerId = input.ReadString();
case 112: {
TeamManagerId = input.ReadInt32();
break;
}
}
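Why the raw tags 114 and 112: a protobuf key byte is `(field_number << 3) | wire_type`. A tiny sketch that reproduces the constants seen in `WriteTo`/`CalculateSize` above:

```csharp
static class WireTagSketch
{
    // team_manager_id is field 14: as a string (wire type 2) its key byte was
    // (14 << 3) | 2 = 114; as a varint int32 (wire type 0) it is (14 << 3) | 0 = 112.
    static void Main()
    {
        const int fieldNumber = 14;
        System.Console.WriteLine((fieldNumber << 3) | 2); // 114
        System.Console.WriteLine((fieldNumber << 3) | 0); // 112
    }
}
```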

39
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMilAEKGFVuaXR5UkxD",
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMirwEKGFVuaXR5UkxD",
"ASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZw",
"cm90bzM="));
"ASgIEhkKEXRyYWluaW5nQW5hbHl0aWNzGAUgASgIQiWqAiJVbml0eS5NTEFn",
"ZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics" }, null, null, null)
}));
}
#endregion

concatenatedPngObservations_ = other.concatenatedPngObservations_;
compressedChannelMapping_ = other.compressedChannelMapping_;
hybridActions_ = other.hybridActions_;
trainingAnalytics_ = other.trainingAnalytics_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

}
}
/// <summary>Field number for the "trainingAnalytics" field.</summary>
public const int TrainingAnalyticsFieldNumber = 5;
private bool trainingAnalytics_;
/// <summary>
/// support for training analytics
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool TrainingAnalytics {
get { return trainingAnalytics_; }
set {
trainingAnalytics_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);

if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
if (HybridActions != other.HybridActions) return false;
if (TrainingAnalytics != other.TrainingAnalytics) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
if (HybridActions != false) hash ^= HybridActions.GetHashCode();
if (TrainingAnalytics != false) hash ^= TrainingAnalytics.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteRawTag(32);
output.WriteBool(HybridActions);
}
if (TrainingAnalytics != false) {
output.WriteRawTag(40);
output.WriteBool(TrainingAnalytics);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

if (HybridActions != false) {
size += 1 + 1;
}
if (TrainingAnalytics != false) {
size += 1 + 1;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

}
if (other.HybridActions != false) {
HybridActions = other.HybridActions;
}
if (other.TrainingAnalytics != false) {
TrainingAnalytics = other.TrainingAnalytics;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

}
case 32: {
HybridActions = input.ReadBool();
break;
}
case 40: {
TrainingAnalytics = input.ReadBool();
break;
}
}

52
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjRtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyK7AgoQT2JzZXJ2YXRp",
"aW9uLnByb3RvEhRjb21tdW5pY2F0b3Jfb2JqZWN0cyKBAwoQT2JzZXJ2YXRp",
"KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUaGQoJRmxvYXREYXRh",
"EgwKBGRhdGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVz",
"c2lvblR5cGVQcm90bxIICgROT05FEAASBwoDUE5HEAFCJaoCIlVuaXR5Lk1M",
"QWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
"KAUSHAoUZGltZW5zaW9uX3Byb3BlcnRpZXMYBiADKAUSRAoQb2JzZXJ2YXRp",
"b25fdHlwZRgHIAEoDjIqLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0",
"aW9uVHlwZVByb3RvGhkKCUZsb2F0RGF0YRIMCgRkYXRhGAEgAygCQhIKEG9i",
"c2VydmF0aW9uX2RhdGEqKQoUQ29tcHJlc3Npb25UeXBlUHJvdG8SCAoETk9O",
"RRAAEgcKA1BORxABKkYKFE9ic2VydmF0aW9uVHlwZVByb3RvEgsKB0RFRkFV",
"TFQQABIICgRHT0FMEAESCgoGUkVXQVJEEAISCwoHTUVTU0FHRRADQiWqAiJV",
"bml0eS5NTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto), }, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Parser, new[]{ "Shape", "CompressionType", "CompressedData", "FloatData", "CompressedChannelMapping", "DimensionProperties", "ObservationType" }, new[]{ "ObservationData" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData), global::Unity.MLAgents.CommunicatorObjects.ObservationProto.Types.FloatData.Parser, new[]{ "Data" }, null, null, null)})
}));
}
#endregion

internal enum CompressionTypeProto {
[pbr::OriginalName("NONE")] None = 0,
[pbr::OriginalName("PNG")] Png = 1,
}
internal enum ObservationTypeProto {
[pbr::OriginalName("DEFAULT")] Default = 0,
[pbr::OriginalName("GOAL")] Goal = 1,
[pbr::OriginalName("REWARD")] Reward = 2,
[pbr::OriginalName("MESSAGE")] Message = 3,
}
#endregion

compressionType_ = other.compressionType_;
compressedChannelMapping_ = other.compressedChannelMapping_.Clone();
dimensionProperties_ = other.dimensionProperties_.Clone();
observationType_ = other.observationType_;
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;

get { return dimensionProperties_; }
}
/// <summary>Field number for the "observation_type" field.</summary>
public const int ObservationTypeFieldNumber = 7;
private global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto observationType_ = 0;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto ObservationType {
get { return observationType_; }
set {
observationType_ = value;
}
}
private object observationData_;
/// <summary>Enum of possible cases for the "observation_data" oneof.</summary>
public enum ObservationDataOneofCase {

if (!object.Equals(FloatData, other.FloatData)) return false;
if(!compressedChannelMapping_.Equals(other.compressedChannelMapping_)) return false;
if(!dimensionProperties_.Equals(other.dimensionProperties_)) return false;
if (ObservationType != other.ObservationType) return false;
if (ObservationDataCase != other.ObservationDataCase) return false;
return Equals(_unknownFields, other._unknownFields);
}

if (observationDataCase_ == ObservationDataOneofCase.FloatData) hash ^= FloatData.GetHashCode();
hash ^= compressedChannelMapping_.GetHashCode();
hash ^= dimensionProperties_.GetHashCode();
if (ObservationType != 0) hash ^= ObservationType.GetHashCode();
hash ^= (int) observationDataCase_;
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();

}
compressedChannelMapping_.WriteTo(output, _repeated_compressedChannelMapping_codec);
dimensionProperties_.WriteTo(output, _repeated_dimensionProperties_codec);
if (ObservationType != 0) {
output.WriteRawTag(56);
output.WriteEnum((int) ObservationType);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

}
size += compressedChannelMapping_.CalculateSize(_repeated_compressedChannelMapping_codec);
size += dimensionProperties_.CalculateSize(_repeated_dimensionProperties_codec);
if (ObservationType != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ObservationType);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

}
compressedChannelMapping_.Add(other.compressedChannelMapping_);
dimensionProperties_.Add(other.dimensionProperties_);
if (other.ObservationType != 0) {
ObservationType = other.ObservationType;
}
switch (other.ObservationDataCase) {
case ObservationDataOneofCase.CompressedData:
CompressedData = other.CompressedData;

case 50:
case 48: {
dimensionProperties_.AddEntriesFrom(input, _repeated_dimensionProperties_codec);
break;
}
case 56: {
observationType_ = (global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto) input.ReadEnum();
break;
}
}

21
com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs


m_ActionSpec = actionSpec;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
foreach (int agentId in actionIds)
for (var i = 0; i < actionIds.Count; i++)
var agentId = actionIds[i];
if (lastActions.ContainsKey(agentId))
{
var actionBuffer = lastActions[agentId];

m_ActionSpec = actionSpec;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
//var tensorDataProbabilities = tensorProxy.Data as float[,];
var idActionPairList = actionIds as List<int> ?? actionIds.ToList();

actionProbs.data.Dispose();
outputTensor.data.Dispose();
}
foreach (int agentId in actionIds)
for (var i = 0; i < actionIds.Count; i++)
var agentId = actionIds[i];
if (lastActions.ContainsKey(agentId))
{
var actionBuffer = lastActions[agentId];

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
foreach (int agentId in actionIds)
for (var i = 0; i < actionIds.Count; i++)
var agentId = actionIds[i];
List<float> memory;
if (!m_Memories.TryGetValue(agentId, out memory)
|| memory.Count < memorySize)

m_Memories = memories;
}
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
foreach (int agentId in actionIds)
for (var i = 0; i < actionIds.Count; i++)
var agentId = actionIds[i];
List<float> memory;
if (!m_Memories.TryGetValue(agentId, out memory)
|| memory.Count < memorySize * m_MemoriesCount)
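The `IEnumerable<int>` to `IList<int>` change across these appliers is an allocation fix: `foreach` over an interface-typed enumerable boxes the underlying `List<int>.Enumerator` on every call, while an indexed `for` does not. A minimal sketch of the pattern (names are illustrative):

```csharp
using System.Collections.Generic;

static class NoAllocLoopSketch
{
    // Indexed iteration over IList<int> avoids the per-call enumerator
    // allocation that foreach over IEnumerable<int> incurs.
    public static long Sum(IList<int> values)
    {
        long total = 0;
        for (var i = 0; i < values.Count; i++)
        {
            total += values[i];
        }
        return total;
    }
}
```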

70
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
{
var sensor = sensorComponents[sensorIndex];
if (!sensor.IsVisual())
if (sensor.GetObservationShape().Length == 3)
continue;
if (!tensorsNames.Contains(
TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
{
failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
}
visObsIndex++;
if (!tensorsNames.Contains(
TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
if (sensor.GetObservationShape().Length == 2)
failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
if (!tensorsNames.Contains(
TensorNames.ObservationPlaceholderPrefix + sensorIndex))
{
failedModelChecks.Add(
"The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).");
}
visObsIndex++;
}
var expectedVisualObs = model.GetNumVisualInputs();

}
/// <summary>
/// Checks that the shape of the rank 2 observation input placeholder is the same as the corresponding sensor.
/// </summary>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensorComponent">The sensor that produces the observations.</param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckRankTwoObsShape(
TensorProxy tensorProxy, SensorComponent sensorComponent)
{
var shape = sensorComponent.GetObservationShape();
var dim1Bp = shape[0];
var dim2Bp = shape[1];
var dim1T = tensorProxy.Channels;
var dim2T = tensorProxy.Width;
if ((dim1Bp != dim1T) || (dim2Bp != dim2T))
{
return $"An Observation of the model does not match. " +
$"Received TensorProxy of shape [?x{dim1Bp}x{dim2Bp}] but " +
$"was expecting [?x{dim1T}x{dim2T}].";
}
return null;
}
/// <summary>
/// Generates failed checks that correspond to inputs shapes incompatibilities between
/// the model and the BrainParameters.
/// </summary>

for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
{
var sensorComponent = sensorComponents[sensorIndex];
if (!sensorComponent.IsVisual())
if (sensorComponent.GetObservationShape().Length == 3)
{
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
visObsIndex++;
}
if (sensorComponent.GetObservationShape().Length == 2)
continue;
tensorTester[TensorNames.ObservationPlaceholderPrefix + sensorIndex] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sensorComponent);
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
visObsIndex++;
}
// If the model expects an input but it is not in this list

var totalVectorSensorSize = 0;
foreach (var sensorComp in sensorComponents)
{
if (sensorComp.IsVector())
if (sensorComp.GetObservationShape().Length == 1)
{
totalVectorSensorSize += sensorComp.GetObservationShape()[0];
}

var sensorSizes = "";
foreach (var sensorComp in sensorComponents)
{
if (sensorComp.IsVector())
if (sensorComp.GetObservationShape().Length == 1)
{
var vecSize = sensorComp.GetObservationShape()[0];
if (sensorSizes.Length == 0)

126
com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs


m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
}

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
{
tensorProxy.data?.Dispose();
tensorProxy.data = m_Allocator.Alloc(new TensorShape(1, 1));

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
{
tensorProxy.shape = new long[0];
tensorProxy.data?.Dispose();

}
/// <summary>
/// Generates the Tensor corresponding to the VectorObservation input : Will be a two
/// dimensional float array of dimension [batchSize x vectorObservationSize].
/// It will use the Vector Observation data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
internal class VectorObservationGenerator : TensorGenerator.IGenerator
{
readonly ITensorAllocator m_Allocator;
List<int> m_SensorIndices = new List<int>();
ObservationWriter m_ObservationWriter = new ObservationWriter();
public VectorObservationGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void AddSensorIndex(int sensorIndex)
{
m_SensorIndices.Add(sensorIndex);
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
var vecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];
var agentIndex = 0;
foreach (var info in infos)
{
if (info.agentInfo.done)
{
// If the agent is done, we might have a stale reference to the sensors
// e.g. a dependent object might have been disposed.
// To avoid this, just fill observation with zeroes instead of calling sensor.Write.
TensorUtils.FillTensorBatch(tensorProxy, agentIndex, 0.0f);
}
else
{
var tensorOffset = 0;
// Write each sensor consecutively to the tensor
foreach (var sensorIndex in m_SensorIndices)
{
var sensor = info.sensors[sensorIndex];
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, tensorOffset);
var numWritten = sensor.Write(m_ObservationWriter);
tensorOffset += numWritten;
}
Debug.AssertFormat(
tensorOffset == vecObsSizeT,
"mismatch between vector observation size ({0}) and number of observations written ({1})",
vecObsSizeT, tensorOffset
);
}
agentIndex++;
}
}
}
/// <summary>
/// Generates the Tensor corresponding to the Recurrent input : Will be a two
/// dimensional float array of dimension [batchSize x memorySize].
/// It will use the Memory data contained in the agentInfo to fill the data

}
public void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
foreach (var infoSensorPair in infos)
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
var infoSensorPair = infos[infoIndex];
var info = infoSensorPair.agentInfo;
List<float> memory;

m_Memories = memories;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
foreach (var infoSensorPair in infos)
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
var infoSensorPair = infos[infoIndex];
var info = infoSensorPair.agentInfo;
var offset = memorySize * m_MemoryIndex;
List<float> memory;

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
foreach (var infoSensorPair in infos)
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
var infoSensorPair = infos[infoIndex];
var info = infoSensorPair.agentInfo;
var pastAction = info.storedActions.DiscreteActions;
if (!pastAction.IsEmpty())

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
foreach (var infoSensorPair in infos)
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
var infoSensorPair = infos[infoIndex];
var agentInfo = infoSensorPair.agentInfo;
var maskList = agentInfo.discreteActionMasks;
for (var j = 0; j < maskSize; j++)

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
TensorUtils.FillTensorWithRandomNormal(tensorProxy, m_RandomNormal);

/// <summary>
/// Generates the Tensor corresponding to the Visual Observation input : Will be a 4
/// dimensional float array of dimension [batchSize x width x height x numChannels].
/// It will use the Texture input data contained in the agentInfo to fill the data
/// Generates the Tensor corresponding to the Observation input : Will be a multi
/// dimensional float array.
/// It will use the Observation data contained in the sensors to fill the data
internal class VisualObservationInputGenerator : TensorGenerator.IGenerator
internal class ObservationGenerator : TensorGenerator.IGenerator
readonly int m_SensorIndex;
List<int> m_SensorIndices = new List<int>();
public VisualObservationInputGenerator(
int sensorIndex, ITensorAllocator allocator)
public ObservationGenerator(ITensorAllocator allocator)
m_SensorIndex = sensorIndex;
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
public void AddSensorIndex(int sensorIndex)
{
m_SensorIndices.Add(sensorIndex);
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
foreach (var infoSensorPair in infos)
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
var sensor = infoSensorPair.sensors[m_SensorIndex];
if (infoSensorPair.agentInfo.done)
var info = infos[infoIndex];
if (info.agentInfo.done)
{
// If the agent is done, we might have a stale reference to the sensors
// e.g. a dependent object might have been disposed.

else
{
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, 0);
sensor.Write(m_ObservationWriter);
var tensorOffset = 0;
// Write each sensor consecutively to the tensor
for (var sensorIndexIndex = 0; sensorIndexIndex < m_SensorIndices.Count; sensorIndexIndex++)
{
var sensorIndex = m_SensorIndices[sensorIndexIndex];
var sensor = info.sensors[sensorIndex];
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, tensorOffset);
var numWritten = sensor.Write(m_ObservationWriter);
tensorOffset += numWritten;
}
}
agentIndex++;
}
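The generator now writes several sensors back-to-back into one row by advancing a write offset. A standalone sketch of that bookkeeping (types simplified to plain arrays):

```csharp
using System.Collections.Generic;

static class ConcatWriteSketch
{
    // Each sensor reports how many floats it wrote; the offset advances by
    // that amount so the next sensor lands immediately after it.
    public static int WriteAll(IReadOnlyList<float[]> sensorOutputs, float[] row)
    {
        var offset = 0;
        for (var i = 0; i < sensorOutputs.Count; i++)
        {
            sensorOutputs[i].CopyTo(row, offset);
            offset += sensorOutputs[i].Length;
        }
        return offset; // should equal the model's expected observation size
    }
}
```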

46
com.unity.ml-agents/Runtime/Inference/ModelRunner.cs


TensorApplier m_TensorApplier;
NNModel m_Model;
string m_ModelName;
IReadOnlyList<TensorProxy> m_InferenceOutputs;
List<TensorProxy> m_InferenceOutputs;
Dictionary<string, Tensor> m_InputsByName;
Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();
SensorShapeValidator m_SensorShapeValidator = new SensorShapeValidator();

{
Model barracudaModel;
m_Model = model;
m_ModelName = model.name;
m_InferenceDevice = inferenceDevice;
m_TensorAllocator = new TensorCachingAllocator();
if (model != null)

seed, m_TensorAllocator, m_Memories, barracudaModel);
m_TensorApplier = new TensorApplier(
actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel);
m_InputsByName = new Dictionary<string, Tensor>();
m_InferenceOutputs = new List<TensorProxy>();
}
public InferenceDevice InferenceDevice

get { return m_Model; }
}
static Dictionary<string, Tensor> PrepareBarracudaInputs(IEnumerable<TensorProxy> infInputs)
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)
var inputs = new Dictionary<string, Tensor>();
foreach (var inp in infInputs)
m_InputsByName.Clear();
for (var i = 0; i < infInputs.Count; i++)
inputs[inp.name] = inp.data;
var inp = infInputs[i];
m_InputsByName[inp.name] = inp.data;
return inputs;
}
public void Dispose()

m_TensorAllocator?.Reset(false);
}
List<TensorProxy> FetchBarracudaOutputs(string[] names)
void FetchBarracudaOutputs(string[] names)
var outputs = new List<TensorProxy>();
m_InferenceOutputs.Clear();
outputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
m_InferenceOutputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
return outputs;
}
public void PutObservations(AgentInfo info, List<ISensor> sensors)

}
Profiler.BeginSample("ModelRunner.DecideAction");
Profiler.BeginSample(m_ModelName);
Profiler.BeginSample($"MLAgents.{m_Model.name}.GenerateTensors");
Profiler.BeginSample($"GenerateTensors");
Profiler.BeginSample($"MLAgents.{m_Model.name}.PrepareBarracudaInputs");
var inputs = PrepareBarracudaInputs(m_InferenceInputs);
Profiler.BeginSample($"PrepareBarracudaInputs");
PrepareBarracudaInputs(m_InferenceInputs);
Profiler.BeginSample($"MLAgents.{m_Model.name}.ExecuteGraph");
m_Engine.Execute(inputs);
Profiler.BeginSample($"ExecuteGraph");
m_Engine.Execute(m_InputsByName);
Profiler.BeginSample($"MLAgents.{m_Model.name}.FetchBarracudaOutputs");
m_InferenceOutputs = FetchBarracudaOutputs(m_OutputNames);
Profiler.BeginSample($"FetchBarracudaOutputs");
FetchBarracudaOutputs(m_OutputNames);
Profiler.BeginSample($"MLAgents.{m_Model.name}.ApplyTensors");
Profiler.BeginSample($"ApplyTensors");
Profiler.EndSample();
Profiler.EndSample(); // end name
Profiler.EndSample(); // end ModelRunner.DecideAction
m_Infos.Clear();
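`PrepareBarracudaInputs` and `FetchBarracudaOutputs` now refill member collections instead of returning new ones; `Clear()` keeps the backing storage, so steady-state inference stops allocating here. A minimal sketch of the reuse pattern (names are illustrative):

```csharp
using System.Collections.Generic;

class InputCacheSketch
{
    // One dictionary for the lifetime of the runner: Clear() empties it without
    // releasing its buckets, so refilling it each decision is allocation-free.
    readonly Dictionary<string, float[]> m_InputsByName = new Dictionary<string, float[]>();

    public void Prepare(IReadOnlyList<KeyValuePair<string, float[]>> inputs)
    {
        m_InputsByName.Clear();
        for (var i = 0; i < inputs.Count; i++)
        {
            m_InputsByName[inputs[i].Key] = inputs[i].Value;
        }
    }
}
```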

7
com.unity.ml-agents/Runtime/Inference/TensorApplier.cs


/// </param>
/// <param name="actionIds"> List of Agents Ids that will be updated using the tensor's data</param>
/// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions);
void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions);
}
readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();

/// <exception cref="UnityAgentsException"> One of the tensors does not have an
/// associated applier.</exception>
public void ApplyTensors(
IEnumerable<TensorProxy> tensors, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
IReadOnlyList<TensorProxy> tensors, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
foreach (var tensor in tensors)
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
var tensor = tensors[tensorIndex];
if (!m_Dict.ContainsKey(tensor.name))
{
throw new UnityAgentsException(

63
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


/// the tensor's data.
/// </param>
void Generate(
TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos);
TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos);
}
readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();

public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
{
// Loop through the sensors on a representative agent.
// For vector observations, add the index to the (single) VectorObservationGenerator
// For visual observations, make a VisualObservationInputGenerator
// All vector observations use a shared ObservationGenerator since they are concatenated.
// All other observations use a unique ObservationInputGenerator
VectorObservationGenerator vecObsGen = null;
ObservationGenerator vecObsGen = null;
// TODO generalize - we currently only have vector or visual, but can't handle "2D" observations
var isVectorSensor = (shape.Length == 1);
if (isVectorSensor)
var rank = shape.Length;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
if (vecObsGen == null)
{
vecObsGen = new VectorObservationGenerator(allocator);
}
vecObsGen.AddSensorIndex(sensorIndex);
}
else
{
m_Dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] =
new VisualObservationInputGenerator(sensorIndex, allocator);
visIndex++;
case 1:
if (vecObsGen == null)
{
vecObsGen = new ObservationGenerator(allocator);
}
obsGen = vecObsGen;
obsGenName = TensorNames.VectorObservationPlaceholder;
break;
case 2:
// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.ObservationPlaceholderPrefix + sensorIndex;
break;
case 3:
// If the tensor is of rank 3, we use the "visual observation
// index", which only counts the rank 3 sensors
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.VisualObservationPlaceholderPrefix + visIndex;
visIndex++;
break;
default:
throw new UnityAgentsException(
$"Sensor {sensor.GetName()} has an invalid rank {rank}");
}
if (vecObsGen != null)
{
m_Dict[TensorNames.VectorObservationPlaceholder] = vecObsGen;
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}
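The switch above routes each sensor to a model input by rank. A standalone sketch of the naming scheme, using the placeholder strings from `TensorNames` (the helper itself is illustrative):

```csharp
using System;

static class PlaceholderNameSketch
{
    // Rank-1 sensors share the single vector placeholder; rank-2 are named by
    // sensor index; rank-3 use a separate visual-observation counter.
    public static string PlaceholderFor(int rank, int sensorIndex, ref int visIndex)
    {
        switch (rank)
        {
            case 1: return "vector_observation";
            case 2: return "obs_" + sensorIndex;
            case 3: return "visual_observation_" + visIndex++;
            default: throw new ArgumentOutOfRangeException(nameof(rank));
        }
    }
}
```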

/// <exception cref="UnityAgentsException"> One of the tensors does not have an
/// associated generator.</exception>
public void GenerateTensors(
IEnumerable<TensorProxy> tensors, int currentBatchSize, IEnumerable<AgentInfoSensorsPair> infos)
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<AgentInfoSensorsPair> infos)
foreach (var tensor in tensors)
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
var tensor = tensors[tensorIndex];
if (!m_Dict.ContainsKey(tensor.name))
{
throw new UnityAgentsException(

1
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


public const string recurrentInPlaceholderH = "recurrent_in_h";
public const string recurrentInPlaceholderC = "recurrent_in_c";
public const string VisualObservationPlaceholderPrefix = "visual_observation_";
public const string ObservationPlaceholderPrefix = "obs_";
public const string PreviousActionPlaceholder = "prev_action";
public const string ActionMaskPlaceholder = "action_masks";
public const string RandomNormalEpsilonPlaceholder = "epsilon";

9
com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs


get { return m_BehaviorName + "?team=" + TeamId; }
}
internal IPolicy GeneratePolicy(ActionSpec actionSpec, HeuristicPolicy.ActionGenerator heuristic)
internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorManager)
return new HeuristicPolicy(heuristic, actionSpec);
return new HeuristicPolicy(actuatorManager, actionSpec);
case BehaviorType.InferenceOnly:
{
if (m_Model == null)

}
else
{
return new HeuristicPolicy(heuristic, actionSpec);
return new HeuristicPolicy(actuatorManager, actionSpec);
return new HeuristicPolicy(heuristic, actionSpec);
return new HeuristicPolicy(actuatorManager, actionSpec);
}
}

}
agent.ReloadPolicy();
}
}
}

11
com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs


namespace Unity.MLAgents.Policies
{
/// <summary>
/// The Heuristic Policy uses a hards coded Heuristic method
/// The Heuristic Policy uses a hard-coded Heuristic method
public delegate void ActionGenerator(in ActionBuffers actionBuffers);
ActionGenerator m_Heuristic;
ActuatorManager m_ActuatorManager;
ActionBuffers m_ActionBuffers;
bool m_Done;
bool m_DecisionRequested;

/// <inheritdoc />
public HeuristicPolicy(ActionGenerator heuristic, ActionSpec actionSpec)
public HeuristicPolicy(ActuatorManager actuatorManager, ActionSpec actionSpec)
m_Heuristic = heuristic;
m_ActuatorManager = actuatorManager;
var numContinuousActions = actionSpec.NumContinuousActions;
var numDiscreteActions = actionSpec.NumDiscreteActions;
var continuousDecision = new ActionSegment<float>(new float[numContinuousActions], 0, numContinuousActions);

{
if (!m_Done && m_DecisionRequested)
{
m_Heuristic.Invoke(m_ActionBuffers);
m_ActuatorManager.ApplyHeuristic(m_ActionBuffers);
}
m_DecisionRequested = false;
return ref m_ActionBuffers;

14
com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs


using System.Collections.Generic;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Analytics;
namespace Unity.MLAgents.Policies
{

string m_FullyQualifiedBehaviorName;
ActionSpec m_ActionSpec;
ActionBuffers m_LastActionBuffer;
private bool m_AnalyticsSent = false;
internal ICommunicator m_Communicator;

{
m_FullyQualifiedBehaviorName = fullyQualifiedBehaviorName;
m_Communicator = Academy.Instance.Communicator;
m_Communicator.SubscribeBrain(m_FullyQualifiedBehaviorName, actionSpec);
m_Communicator?.SubscribeBrain(m_FullyQualifiedBehaviorName, actionSpec);
m_ActionSpec = actionSpec;
}

if (!m_AnalyticsSent)
{
m_AnalyticsSent = true;
TrainingAnalytics.RemotePolicyInitialized(
m_FullyQualifiedBehaviorName,
sensors,
m_ActionSpec
);
}
m_AgentId = info.episodeId;
m_Communicator?.PutObservations(m_FullyQualifiedBehaviorName, info, sensors);
}

8
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs


namespace Unity.MLAgents.Sensors
{
internal class BufferSensor : ISensor, IDimensionPropertiesSensor
internal class BufferSensor : ISensor, IDimensionPropertiesSensor, IBuiltInSensor
{
private int m_MaxNumObs;
private int m_ObsSize;

public string GetName()
{
return "BufferSensor";
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.BufferSensor;
}
}

15
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


using UnityEngine;
using UnityEngine.Rendering;
namespace Unity.MLAgents.Sensors
{

public class CameraSensor : ISensor
public class CameraSensor : ISensor, IBuiltInSensor
{
Camera m_Camera;
int m_Width;

/// <returns>Texture2D to render to.</returns>
public static Texture2D ObservationToTexture(Camera obsCamera, int width, int height)
{
if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null)
{
Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render.");
}
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
var oldRec = obsCamera.rect;
obsCamera.rect = new Rect(0f, 0f, 1f, 1f);

Object.Destroy(texture);
}
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.CameraSensor;
}
}
}

24
com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs


}
/// <summary>
/// 1D write access at a specified index. Use AddRange if possible instead.
/// 1D write access at a specified index. Use AddList if possible instead.
/// </summary>
/// <param name="index">Index to write to.</param>
public float this[int index]

/// </summary>
/// <param name="data"></param>
/// <param name="writeOffset">Optional write offset.</param>
[Obsolete("Use AddList() for better performance")]
public void AddRange(IEnumerable<float> data, int writeOffset = 0)
{
if (m_Data != null)

{
m_Proxy.data[m_Batch, index + m_Offset + writeOffset] = val;
index++;
}
}
}
public void AddList(IList<float> data, int writeOffset = 0)
{
if (m_Data != null)
{
for (var index = 0; index < data.Count; index++)
{
var val = data[index];
m_Data[index + m_Offset + writeOffset] = val;
}
}
else
{
for (var index = 0; index < data.Count; index++)
{
var val = data[index];
m_Proxy.data[m_Batch, index + m_Offset + writeOffset] = val;
}
}
}
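`AddList` is the replacement for the deprecated `AddRange`: it indexes the `IList<float>` directly instead of allocating an enumerator. A hypothetical sensor using it in `Write` (member signatures assume the `ISensor` surface at this release):

```csharp
using Unity.MLAgents.Sensors;

public class ConstantSensor : ISensor
{
    readonly float[] m_Values = { 0.1f, 0.2f, 0.3f };

    public int[] GetObservationShape() { return new[] { m_Values.Length }; }

    public int Write(ObservationWriter writer)
    {
        writer.AddList(m_Values); // float[] implements IList<float>
        return m_Values.Length;
    }

    public byte[] GetCompressedObservation() { return null; }
    public void Update() { }
    public void Reset() { }
    public SensorCompressionType GetCompressionType() { return SensorCompressionType.None; }
    public string GetName() { return "ConstantSensor"; }
}
```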

10
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


/// <summary>
/// A sensor implementation that supports ray cast-based observations.
/// </summary>
public class RayPerceptionSensor : ISensor
public class RayPerceptionSensor : ISensor, IBuiltInSensor
{
float[] m_Observations;
int[] m_Shape;

rayOutput.ToFloatArray(numDetectableTags, rayIndex, m_Observations);
}
// Finally, add the observations to the ObservationWriter
writer.AddRange(m_Observations);
writer.AddList(m_Observations);
}
return m_Observations.Length;
}

public virtual SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.RayPerceptionSensor;
}
/// <summary>

9
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


/// <summary>
/// Abstract base class for reflection-based sensors.
/// </summary>
internal abstract class ReflectionSensorBase : ISensor
internal abstract class ReflectionSensorBase : ISensor, IBuiltInSensor
{
protected object m_Object;

{
return m_SensorName;
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.ReflectionSensor;
}
}
}

9
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs


/// <summary>
/// Sensor class that wraps a [RenderTexture](https://docs.unity3d.com/ScriptReference/RenderTexture.html) instance.
/// </summary>
public class RenderTextureSensor : ISensor
public class RenderTextureSensor : ISensor, IBuiltInSensor
{
RenderTexture m_RenderTexture;
bool m_Grayscale;

{
return m_CompressionType;
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.RenderTextureSensor;
}
/// <summary>
/// Converts a RenderTexture to a 2D texture.

3
com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs


using UnityEngine;
using System;
namespace Unity.MLAgents.Sensors
{

/// Whether the observation is visual or not.
/// </summary>
/// <returns>True if the observation is visual, false otherwise.</returns>
[Obsolete("IsVisual is deprecated, please use GetObservationShape() instead.")]
public virtual bool IsVisual()
{
var shape = GetObservationShape();

/// Whether the observation is vector or not.
/// </summary>
/// <returns>True if the observation is vector, false otherwise.</returns>
[Obsolete("IsVisual is deprecated, please use GetObservationShape() instead.")]
public virtual bool IsVector()
{
var shape = GetObservationShape();

7
com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs


{
// Check for compatibility with the other Agents' Sensors
// TODO make sure this only checks once per agent
Debug.Assert(m_SensorShapes.Count == sensors.Count, $"Number of Sensors must match. {m_SensorShapes.Count} != {sensors.Count}");
Debug.AssertFormat(
m_SensorShapes.Count == sensors.Count,
"Number of Sensors must match. {0} != {1}",
m_SensorShapes.Count,
sensors.Count
);
for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++)
{
var cachedShape = m_SensorShapes[i];
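The switch from `Debug.Assert` with an interpolated message to `Debug.AssertFormat` defers string formatting to the failure path; the interpolated version builds its message on every call, even when the assertion passes. A minimal sketch:

```csharp
using UnityEngine;

static class AssertSketch
{
    // The format string and arguments are only combined if the condition fails.
    public static void CheckCounts(int expected, int actual)
    {
        Debug.AssertFormat(
            expected == actual,
            "Number of Sensors must match. {0} != {1}",
            expected,
            actual);
    }
}
```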

11
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


/// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation.
/// Currently, observations are stacked on the last dimension.
/// </summary>
public class StackingSensor : ISparseChannelSensor
public class StackingSensor : ISparseChannelSensor, IBuiltInSensor
{
/// <summary>
/// The wrapped sensor.

for (var i = 0; i < m_NumStackedObservations; i++)
{
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
writer.AddRange(m_StackedObservations[obsIndex], numWritten);
writer.AddList(m_StackedObservations[obsIndex], numWritten);
numWritten += m_UnstackedObservationSize;
}
}

}
}
return compressionMapping;
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
IBuiltInSensor wrappedBuiltInSensor = m_WrappedSensor as IBuiltInSensor;
return wrappedBuiltInSensor?.GetBuiltInSensorType() ?? BuiltInSensorType.Unknown;
}
}
}

24
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using UnityEngine;

/// <summary>
/// A sensor implementation for vector observations.
/// </summary>
public class VectorSensor : ISensor
public class VectorSensor : ISensor, IBuiltInSensor
{
// TODO use float[] instead
// TODO allow setting float[]

m_Observations.Add(0);
}
}
writer.AddRange(m_Observations);
writer.AddList(m_Observations);
return expectedObservations;
}

return SensorCompressionType.None;
}
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.VectorSensor;
}
void Clear()
{
m_Observations.Clear();

/// Adds a collection of float observations to the vector observations of the agent.
/// </summary>
/// <param name="observation">Observation.</param>
[Obsolete("Use AddObservation(IList<float>) for better performance.")]
}
}
/// <summary>
/// Adds a list or array of float observations to the vector observations of the agent.
/// </summary>
/// <param name="observation">Observation.</param>
public void AddObservation(IList<float> observation)
{
for (var i = 0; i < observation.Count; i++)
{
AddFloatObs(observation[i]);
}
}
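A usage sketch of the new `IList<float>` overload (the surrounding class is illustrative):

```csharp
using Unity.MLAgents.Sensors;

static class VectorSensorUsageSketch
{
    static void Main()
    {
        // The IList<float> overload writes by index, avoiding the enumerator
        // allocation of the deprecated IEnumerable<float> overload.
        var sensor = new VectorSensor(3, "example");
        sensor.AddObservation(new float[] { 1f, 2f, 3f });
    }
}
```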

12
com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs


/// </summary>
internal class EngineConfigurationChannel : SideChannel
{
enum ConfigurationType : int
internal enum ConfigurationType : int
{
ScreenResolution = 0,
QualityLevel = 1,

break;
case ConfigurationType.TimeScale:
var timeScale = msg.ReadFloat32();
timeScale = Mathf.Clamp(timeScale, 1, 100);
// There's an upper limit for the timeScale in the editor (but not in the player)
// Always ensure that timeScale >= 1 also,
#if UNITY_EDITOR
const float maxTimeScale = 100f;
#else
const float maxTimeScale = float.PositiveInfinity;
#endif
timeScale = Mathf.Clamp(timeScale, 1, maxTimeScale);
Time.timeScale = timeScale;
break;
case ConfigurationType.TargetFrameRate:

26
com.unity.ml-agents/Runtime/SideChannels/SideChannelManager.cs


/// <returns></returns>
internal static byte[] GetSideChannelMessage(Dictionary<Guid, SideChannel> sideChannels)
{
if (!HasOutgoingMessages(sideChannels))
{
// Early out so that we don't create the MemoryStream or BinaryWriter.
// This is the most common case.
return Array.Empty<byte>();
}
using (var memStream = new MemoryStream())
{
using (var binaryWriter = new BinaryWriter(memStream))

return memStream.ToArray();
}
}
}
/// <summary>
/// Check whether any of the sidechannels have queued messages.
/// </summary>
/// <param name="sideChannels"></param>
/// <returns></returns>
static bool HasOutgoingMessages(Dictionary<Guid, SideChannel> sideChannels)
{
foreach (var sideChannel in sideChannels.Values)
{
var messageList = sideChannel.MessageQueue;
if (messageList.Count > 0)
{
return true;
}
}
return false;
}
/// <summary>

18
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs


manager.WriteActionMask();
Assert.IsTrue(groundTruthMask.SequenceEqual(manager.DiscreteActionMask.GetMask()));
}
[Test]
public void TestHeuristic()
{
var manager = new ActuatorManager(2);
var va1 = new TestActuator(ActionSpec.MakeDiscrete(1, 2, 3), "name");
var va2 = new TestActuator(ActionSpec.MakeDiscrete(3, 2, 1, 8), "name1");
manager.Add(va1);
manager.Add(va2);
var actionBuf = new ActionBuffers(Array.Empty<float>(), new[] { 0, 0, 0, 0, 0, 0, 0 });
manager.ApplyHeuristic(actionBuf);
Assert.IsTrue(va1.m_HeuristicCalled);
Assert.AreEqual(va1.m_DiscreteBufferSize, 3);
Assert.IsTrue(va2.m_HeuristicCalled);
Assert.AreEqual(va2.m_DiscreteBufferSize, 4);
}
}
}

11
com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs


using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Tests.Actuators
{
internal class TestActuator : IActuator
internal class TestActuator : IActuator, IHeuristicProvider
public bool m_HeuristicCalled;
public int m_DiscreteBufferSize;
public TestActuator(ActionSpec actuatorSpace, string name)
{
ActionSpec = actuatorSpace;

public void ResetData()
{
}
public void Heuristic(in ActionBuffers actionBuffersOut)
{
m_HeuristicCalled = true;
m_DiscreteBufferSize = actionBuffersOut.DiscreteActions.Length;
}
}
}

19
com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs


using System;
using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;

[TestFixture]
public class VectorActuatorTests
{
class TestActionReceiver : IActionReceiver
class TestActionReceiver : IActionReceiver, IHeuristicProvider
public bool HeuristicCalled;
public void OnActionReceived(ActionBuffers actionBuffers)
{

public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
actionMask.WriteMask(Branch, Mask);
}
public void Heuristic(in ActionBuffers actionBuffersOut)
{
HeuristicCalled = true;
}
}

va.WriteDiscreteActionMask(bdam);
Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask()));
}
[Test]
public void TestHeuristic()
{
var ar = new TestActionReceiver();
var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");
va.Heuristic(new ActionBuffers(Array.Empty<float>(), va.ActionSpec.BranchSizes));
Assert.IsTrue(ar.HeuristicCalled);
}
}
}

19
com.unity.ml-agents/Tests/Editor/Analytics/InferenceAnalyticsTests.cs


[SetUp]
public void SetUp()
{
if (Academy.IsInitialized)
{
Academy.Instance.Dispose();
}
continuousONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousONNXPath, typeof(NNModel));
var go = new GameObject("SensorA");
sensor_21_20_3 = go.AddComponent<Test3DSensorComponent>();

Assert.AreEqual(3, continuousEvent.ObservationSpecs[0].DimensionInfos.Length);
Assert.AreEqual(20, continuousEvent.ObservationSpecs[0].DimensionInfos[0].Size);
Assert.AreEqual("None", continuousEvent.ObservationSpecs[0].CompressionType);
Assert.AreEqual(Test3DSensor.k_BuiltInSensorType, continuousEvent.ObservationSpecs[0].BuiltInSensorType);
Assert.AreNotEqual(null, continuousEvent.ModelHash);
// Make sure nested fields get serialized

Assert.IsTrue(jsonString.Contains("NumDiscreteActions"));
Assert.IsTrue(jsonString.Contains("SensorName"));
Assert.IsTrue(jsonString.Contains("Flags"));
}
[Test]
public void TestBarracudaPolicy()
{
// Explicitly request decisions for a policy so we get code coverage on the event sending
using (new AnalyticsUtils.DisableAnalyticsSending())
{
var sensors = new List<ISensor> { sensor_21_20_3.Sensor, sensor_20_22_3.Sensor };
var policy = new BarracudaPolicy(GetContinuous2vis8vec2actionActionSpec(), continuousONNXModel, InferenceDevice.CPU, "testBehavior");
policy.RequestDecision(new AgentInfo(), sensors);
}
Academy.Instance.Dispose();
}
}
}

6
com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs


namespace Unity.MLAgents.Tests
{
[TestFixture]
public class BehaviorParameterTests
public class BehaviorParameterTests : IHeuristicProvider
static void DummyHeuristic(in ActionBuffers actionsOut)
public void Heuristic(in ActionBuffers actionsOut)
{
// No-op
}

Assert.Throws<UnityAgentsException>(() =>
{
bp.GeneratePolicy(actionSpec, DummyHeuristic);
bp.GeneratePolicy(actionSpec, new ActuatorManager());
});
}
}

32
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


using NUnit.Framework;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Analytics;
using Unity.MLAgents.CommunicatorObjects;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests

Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false);
sparseChannelSensor.Mapping = new[] { 0, 0, 0, 1, 1, 1 };
Assert.AreEqual(GrpcExtensions.IsTrivialMapping(sparseChannelSensor), false);
}
[Test]
public void TestDefaultTrainingEvents()
{
var trainingEnvInit = new TrainingEnvironmentInitialized
{
PythonVersion = "test",
};
var trainingEnvInitEvent = trainingEnvInit.ToTrainingEnvironmentInitializedEvent();
Assert.AreEqual(trainingEnvInit.PythonVersion, trainingEnvInitEvent.TrainerPythonVersion);
var trainingBehavInit = new TrainingBehaviorInitialized
{
BehaviorName = "testBehavior",
ExtrinsicRewardEnabled = true,
CuriosityRewardEnabled = true,
RecurrentEnabled = true,
SelfPlayEnabled = true,
};
var trainingBehavInitEvent = trainingBehavInit.ToTrainingBehaviorInitializedEvent();
Assert.AreEqual(trainingBehavInit.BehaviorName, trainingBehavInitEvent.BehaviorName);
Assert.AreEqual(RewardSignals.Extrinsic | RewardSignals.Curiosity, trainingBehavInitEvent.RewardSignalFlags);
Assert.AreEqual(TrainingFeatures.Recurrent | TrainingFeatures.SelfPlay, trainingBehavInitEvent.TrainingFeatureFlags);
}
}
}

2
com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs


const int batchSize = 4;
var agentInfos = GetFakeAgents(ObservableAttributeOptions.ExamineAll);
var alloc = new TensorCachingAllocator();
var generator = new VectorObservationGenerator(alloc);
var generator = new ObservationGenerator(alloc);
generator.AddSensorIndex(0); // ObservableAttribute (size 1)
generator.AddSensorIndex(1); // TestSensor (size 0)
generator.AddSensorIndex(2); // TestSensor (size 0)

9
com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs


return Sensor.GetObservationShape();
}
}
public class Test3DSensor : ISensor
public class Test3DSensor : ISensor, IBuiltInSensor
// Dummy value for the IBuiltInSensor interface
public const int k_BuiltInSensorType = -42;
public Test3DSensor(string name, int width, int height, int channels)
{

public string GetName()
{
return m_Name;
}
public BuiltInSensorType GetBuiltInSensorType()
{
return (BuiltInSensorType)k_BuiltInSensorType;
}
}

2
com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs


var expectedShape = new[] { height, width, grayscale ? 1 : 3 };
Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape());
Assert.IsTrue(cameraComponent.IsVisual());
Assert.IsFalse(cameraComponent.IsVector());
var sensor = cameraComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());

12
com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs


writer[0] = 3f;
Assert.AreEqual(new[] { 1f, 3f, 2f }, buffer);
// AddRange
// AddList
writer.AddRange(new[] { 4f, 5f });
writer.AddList(new[] { 4f, 5f });
// AddRange with offset
// AddList with offset
writer.AddRange(new[] { 6f, 7f });
writer.AddList(new[] { 6f, 7f });
Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer);
}

Assert.AreEqual(2f, t.data[1, 1]);
Assert.AreEqual(3f, t.data[1, 2]);
// AddRange
// AddList
t = new TensorProxy
{
valueType = TensorProxy.TensorType.FloatingPoint,

writer.SetTarget(t, 1, 1);
writer.AddRange(new[] { -1f, -2f });
writer.AddList(new[] { -1f, -2f });
Assert.AreEqual(0f, t.data[0, 0]);
Assert.AreEqual(0f, t.data[0, 1]);
Assert.AreEqual(0f, t.data[0, 2]);

2
com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs


var expectedShape = new[] { height, width, grayscale ? 1 : 3 };
Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape());
Assert.IsTrue(renderTexComponent.IsVisual());
Assert.IsFalse(renderTexComponent.IsVector());
var sensor = renderTexComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());

15
com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs


{
return Mapping;
}
}
[Test]

var expected4 = sensor.CreateEmptyPNG();
expected4 = expected4.Concat(Array.ConvertAll(new[] { 10f, 11f, 12f }, (z) => (byte)z)).ToArray();
Assert.AreEqual(sensor.GetCompressedObservation(), expected4);
}
[Test]
public void TestStackingSensorBuiltInSensorType()
{
var dummySensor = new Dummy3DSensor();
dummySensor.Shape = new[] { 2, 2, 4 };
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetBuiltInSensorType(), BuiltInSensorType.Unknown);
var vectorSensor = new VectorSensor(4);
var stackedVectorSensor = new StackingSensor(vectorSensor, 4);
Assert.AreEqual(stackedVectorSensor.GetBuiltInSensorType(), BuiltInSensorType.VectorSensor);
}
}
}

4
com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs


var behaviorParams = gameObject.AddComponent<BehaviorParameters>();
behaviorParams.BrainParameters.VectorObservationSize = 3;
behaviorParams.BrainParameters.NumStackedVectorObservations = 2;
behaviorParams.BrainParameters.VectorActionDescriptions = new[] { "TestActionA", "TestActionB" };
behaviorParams.BrainParameters.ActionSpec = ActionSpec.MakeDiscrete(2, 2);
behaviorParams.BrainParameters.VectorActionDescriptions = new[] { "Continuous1", "TestActionA", "TestActionB" };
behaviorParams.BrainParameters.ActionSpec = new ActionSpec(1, new []{2, 2});
behaviorParams.BehaviorName = "TestBehavior";
behaviorParams.TeamId = 42;
behaviorParams.UseChildSensors = true;

2
docs/Background-Unity.md


If you are not familiar with the [Unity Engine](https://unity3d.com/unity), we
highly recommend the [Unity Manual](https://docs.unity3d.com/Manual/index.html)
and [Tutorials page](https://unity3d.com/learn/tutorials). The
[Roll-a-ball tutorial](https://unity3d.com/learn/tutorials/s/roll-ball-tutorial)
[Roll-a-ball tutorial](https://learn.unity.com/project/roll-a-ball)
is a fantastic resource to learn all the basic concepts of Unity to get started
with the ML-Agents Toolkit:

11
docs/Migrating.md


- `UnityEnvironment.API_VERSION` in environment.py
([example](https://github.com/Unity-Technologies/ml-agents/blob/b255661084cb8f701c716b040693069a3fb9a257/ml-agents-envs/mlagents/envs/environment.py#L45))
# Migrating
## Migrating to Release 13
### Implementing IHeuristic in your IActuator implementations
- If you have any custom actuators, you can now implement the `IHeuristicProvider` interface to have your actuator
handle the generation of actions when an Agent is running in heuristic mode; see the sketch after this list.
- `VectorSensor.AddObservation(IEnumerable<float>)` is deprecated. Use `VectorSensor.AddObservation(IList<float>)`
instead.
- `ObservationWriter.AddRange()` is deprecated. Use `ObservationWriter.AddList()` instead.
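A minimal sketch of an actuator opting into heuristic mode (the actuator itself is hypothetical; interface members assume the `IActuator` surface at this release):

```csharp
using Unity.MLAgents.Actuators;

public class JumpActuator : IActuator, IHeuristicProvider
{
    public ActionSpec ActionSpec { get; } = ActionSpec.MakeDiscrete(2);

    public string Name => "JumpActuator";

    public void OnActionReceived(ActionBuffers actionBuffers)
    {
        // React to actionBuffers.DiscreteActions[0] here.
    }

    public void Heuristic(in ActionBuffers actionBuffersOut)
    {
        // Fill the buffers the way a trained policy would.
        var discrete = actionBuffersOut.DiscreteActions;
        discrete[0] = 1;
    }

    public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { }

    public void ResetData() { }
}
```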
# Migrating
## Migrating to Release 11
### Agent virtual method deprecation

9
docs/Python-API.md


A `BehaviorSpec` has the following fields:
- `sensor_specs` is a List of `SensorSpec` objects : Each `SensorSpec`
- `observation_specs` is a List of `ObservationSpec` objects: Each `ObservationSpec`
data should be processed in the corresponding dimension. Note that the `SensorSpec`
have the same ordering as the ordering of observations in the DecisionSteps,
DecisionStep, TerminalSteps and TerminalStep.
data should be processed in the corresponding dimension. `observation_type` is an enum
corresponding to the type of observation that generated the data (e.g., default, goal,
etc.). Note that the `ObservationSpec` objects have the same ordering as the observations
in the DecisionSteps, DecisionStep, TerminalSteps and TerminalStep.
- `action_spec` is an `ActionSpec` namedtuple that defines the number and types
of actions for the Agent.
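As a quick orientation for the renamed field, here is a minimal sketch (assuming a running Unity Editor to attach to; `file_name=None` connects to the Editor rather than launching an executable) that walks each `BehaviorSpec` and prints its `ObservationSpec` entries:

```python
from mlagents_envs.environment import UnityEnvironment

env = UnityEnvironment(file_name=None)  # attach to a running Unity Editor
env.reset()
for behavior_name in env.behavior_specs:
    spec = env.behavior_specs[behavior_name]
    for obs_spec in spec.observation_specs:
        # Each ObservationSpec carries a shape tuple and an ObservationType.
        print(behavior_name, obs_spec.shape, obs_spec.observation_type)
    print(behavior_name, spec.action_spec)
env.close()
```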

10
docs/Training-ML-Agents.md


mlagents-learn --help
```
These additional CLI arguments are grouped into environment, engine and checkpoint. The available settings and example values are shown below.
These additional CLI arguments are grouped into environment, engine, checkpoint and torch.
The available settings and example values are shown below.
#### Environment settings

force: true
train_model: false
inference: false
```
#### Torch settings:
```yaml
torch_settings:
device: cpu
```
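The same device selection is also exposed on the command line through the `--torch-device` argument added in `cli_utils.py` later in this diff; a hypothetical invocation (the config path and run id are placeholders) might look like:

```sh
mlagents-learn config/ppo/3DBall.yaml --torch-device cuda:0 --run-id 3DBall-torch
```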
### Behavior Configurations

16
gym-unity/gym_unity/envs/__init__.py


def _get_n_vis_obs(self) -> int:
result = 0
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 3:
for obs_spec in self.group_spec.observation_specs:
if len(obs_spec.shape) == 3:
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 3:
result.append(sen_spec.shape)
for obs_spec in self.group_spec.observation_specs:
if len(obs_spec.shape) == 3:
result.append(obs_spec.shape)
return result
def _get_vis_obs_list(

def _get_vec_obs_size(self) -> int:
result = 0
for sen_spec in self.group_spec.sensor_specs:
if len(sen_spec.shape) == 1:
result += sen_spec.shape[0]
for obs_spec in self.group_spec.observation_specs:
if len(obs_spec.shape) == 1:
result += obs_spec.shape[0]
return result
def render(self, mode="rgb_array"):

6
gym-unity/gym_unity/tests/test_gym.py


TerminalSteps,
BehaviorMapping,
)
from mlagents.trainers.tests.dummy_config import create_sensor_specs_with_shapes
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes
def test_gym_wrapper():

obs_shapes = [(vector_observation_space_size,)]
for _ in range(number_visual_observations):
obs_shapes += [(8, 8, 3)]
sen_spec = create_sensor_specs_with_shapes(obs_shapes)
return BehaviorSpec(sen_spec, action_spec)
obs_spec = create_observation_specs_with_shapes(obs_shapes)
return BehaviorSpec(obs_spec, action_spec)
def create_mock_vector_steps(specs, num_agents=1, number_visual_observations=0):

52
ml-agents-envs/mlagents_envs/base_env.py


Any,
Mapping as MappingType,
)
from enum import IntFlag
from enum import IntFlag, Enum
import numpy as np
from mlagents_envs.exception import UnityActionException

reward: float
agent_id: AgentId
action_mask: Optional[List[np.ndarray]]
team_manager_id: Optional[str]
team_manager_id: int
class DecisionSteps(Mapping):

this simulation step.
"""
def __init__(self, obs, reward, agent_id, action_mask, team_manager_id=None):
def __init__(self, obs, reward, agent_id, action_mask, team_manager_id):
self.team_manager_id: np.ndarray = team_manager_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
@property

agent_mask = []
for mask in self.action_mask:
agent_mask.append(mask[agent_index])
team_manager_id = None
if self.team_manager_id is not None and self.team_manager_id != "":
team_manager_id = self.team_manager_id[agent_index]
team_manager_id = self.team_manager_id[agent_index]
return DecisionStep(
obs=agent_obs,
reward=self.reward[agent_index],

:param spec: The BehaviorSpec for the DecisionSteps
"""
obs: List[np.ndarray] = []
for sen_spec in spec.sensor_specs:
for sen_spec in spec.observation_specs:
obs += [np.zeros((0,) + sen_spec.shape, dtype=np.float32)]
return DecisionSteps(
obs=obs,

team_manager_id=None,
team_manager_id=np.zeros(0, dtype=np.int32),
)

reward: float
interrupted: bool
agent_id: AgentId
team_manager_id: Optional[str]
team_manager_id: int
class TerminalSteps(Mapping):

across simulation steps.
"""
def __init__(self, obs, reward, interrupted, agent_id, team_manager_id=None):
def __init__(self, obs, reward, interrupted, agent_id, team_manager_id):
self.team_manager_id: np.ndarray = team_manager_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
self.team_manager_id: Optional[List[str]] = team_manager_id

agent_obs = []
for batched_obs in self.obs:
agent_obs.append(batched_obs[agent_index])
team_manager_id = None
if self.team_manager_id is not None and self.team_manager_id != "":
team_manager_id = self.team_manager_id[agent_index]
team_manager_id = self.team_manager_id[agent_index]
return TerminalStep(
obs=agent_obs,
reward=self.reward[agent_index],

:param spec: The BehaviorSpec for the TerminalSteps
"""
obs: List[np.ndarray] = []
for sen_spec in spec.sensor_specs:
for sen_spec in spec.observation_specs:
obs += [np.zeros((0,) + sen_spec.shape, dtype=np.float32)]
return TerminalSteps(
obs=obs,

team_manager_id=None,
team_manager_id=np.zeros(0, dtype=np.int32),
)
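Since the empty constructors above now fill `team_manager_id` with a zero-length int32 array rather than `None`, a quick sanity-check sketch (using the `create_observation_specs_with_shapes` test helper that appears elsewhere in this diff):

```python
import numpy as np
from mlagents_envs.base_env import (
    ActionSpec,
    BehaviorSpec,
    DecisionSteps,
    TerminalSteps,
)
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes

spec = BehaviorSpec(
    create_observation_specs_with_shapes([(3,)]),
    ActionSpec.create_continuous(2),
)
# Empty steps carry a zero-length int32 team_manager_id array, not None.
assert DecisionSteps.empty(spec).team_manager_id.dtype == np.int32
assert TerminalSteps.empty(spec).team_manager_id.dtype == np.int32
```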

VARIABLE_SIZE = 4
class SensorSpec(NamedTuple):
class ObservationType(Enum):
"""
An Enum which defines the type of information carried in the observation
of the agent.
"""
# Observation information is generic.
DEFAULT = 0
# Observation contains goal information for current task.
GOAL = 1
# Observation contains reward information for current task.
REWARD = 2
# Observation contains a message from another agent.
MESSAGE = 3
class ObservationSpec(NamedTuple):
"""
A NamedTuple containing information about the observation of Agents.
- shape is a Tuple of int: It corresponds to the shape of

- observation_type is an enum of ObservationType.
observation_type: ObservationType
class BehaviorSpec(NamedTuple):

- sensor_specs is a List of SensorSpec NamedTuple containing
- observation_specs is a List of ObservationSpec NamedTuple containing
information about the Agent's observations, such as their shapes.
The order of the ObservationSpec is the same as the order of the observations of an
agent.

sensor_specs: List[SensorSpec]
observation_specs: List[ObservationSpec]
action_spec: ActionSpec
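A minimal sketch of building these NamedTuples by hand, with the field order assumed from the `ObservationSpec` construction in the `rpc_utils.py` hunk below:

```python
from mlagents_envs.base_env import (
    ActionSpec,
    BehaviorSpec,
    DimensionProperty,
    ObservationSpec,
    ObservationType,
)

goal_obs = ObservationSpec(
    (3,),                       # shape: one 3-float vector observation
    (DimensionProperty.NONE,),  # one DimensionProperty entry per dimension
    ObservationType.GOAL,       # tagged as goal information for the task
)
behavior_spec = BehaviorSpec(
    observation_specs=[goal_obs],
    action_spec=ActionSpec.create_continuous(2),
)
```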

17
ml-agents-envs/mlagents_envs/communicator.py


from typing import Optional
from typing import Callable, Optional
# Function to call while waiting for a connection timeout.
# This should raise an exception if it needs to break from waiting for the timeout.
PollCallback = Callable[[], None]
class Communicator:

:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
"""
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
def initialize(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> UnityOutputProto:
:param poll_callback: Optional callback to be used while polling the connection.
def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]:
def exchange(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> Optional[UnityOutputProto]:
:param poll_callback: Optional callback to be used while polling the connection.
:return: The UnityOutputs generated by the Environment
"""

6
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py


name='mlagents_envs/communicator_objects/agent_info.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\tJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n3mlagents_envs/communicator_objects/agent_info.proto\x12\x14\x63ommunicator_objects\x1a\x34mlagents_envs/communicator_objects/observation.proto\"\xea\x01\n\x0e\x41gentInfoProto\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12<\n\x0cobservations\x18\r \x03(\x0b\x32&.communicator_objects.ObservationProto\x12\x17\n\x0fteam_manager_id\x18\x0e \x01(\x05J\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x05\x10\x06J\x04\x08\x06\x10\x07J\x04\x08\x0c\x10\rB%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
,
dependencies=[mlagents__envs_dot_communicator__objects_dot_observation__pb2.DESCRIPTOR,])

options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='team_manager_id', full_name='communicator_objects.AgentInfoProto.team_manager_id', index=6,
number=14, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
number=14, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),

4
ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi


max_step_reached = ... # type: builtin___bool
id = ... # type: builtin___int
action_mask = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___bool]
team_manager_id = ... # type: typing___Text
team_manager_id = ... # type: builtin___int
@property
def observations(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]: ...

id : typing___Optional[builtin___int] = None,
action_mask : typing___Optional[typing___Iterable[builtin___bool]] = None,
observations : typing___Optional[typing___Iterable[mlagents_envs___communicator_objects___observation_pb2___ObservationProto]] = None,
team_manager_id : typing___Optional[typing___Text] = None,
team_manager_id : typing___Optional[builtin___int] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> AgentInfoProto: ...

11
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py


name='mlagents_envs/communicator_objects/capabilities.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\x94\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\xaf\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x12\x19\n\x11trainingAnalytics\x18\x05 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='trainingAnalytics', full_name='communicator_objects.UnityRLCapabilitiesProto.trainingAnalytics', index=4,
number=5, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=80,
serialized_end=228,
serialized_end=255,
)
DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO

6
ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi


concatenatedPngObservations = ... # type: builtin___bool
compressedChannelMapping = ... # type: builtin___bool
hybridActions = ... # type: builtin___bool
trainingAnalytics = ... # type: builtin___bool
def __init__(self,
*,

hybridActions : typing___Optional[builtin___bool] = None,
trainingAnalytics : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...

def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"trainingAnalytics"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"trainingAnalytics",b"trainingAnalytics"]) -> None: ...

56
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py


name='mlagents_envs/communicator_objects/observation.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\xbb\x02\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x81\x03\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12\x44\n\x10observation_type\x18\x07 \x01(\x0e\x32*.communicator_objects.ObservationTypeProto\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*F\n\x14ObservationTypeProto\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x08\n\x04GOAL\x10\x01\x12\n\n\x06REWARD\x10\x02\x12\x0b\n\x07MESSAGE\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(

],
containing_type=None,
options=None,
serialized_start=396,
serialized_end=437,
serialized_start=466,
serialized_end=507,
_OBSERVATIONTYPEPROTO = _descriptor.EnumDescriptor(
name='ObservationTypeProto',
full_name='communicator_objects.ObservationTypeProto',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='DEFAULT', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='GOAL', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='REWARD', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MESSAGE', index=3, number=3,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=509,
serialized_end=579,
)
_sym_db.RegisterEnumDescriptor(_OBSERVATIONTYPEPROTO)
ObservationTypeProto = enum_type_wrapper.EnumTypeWrapper(_OBSERVATIONTYPEPROTO)
DEFAULT = 0
GOAL = 1
REWARD = 2
MESSAGE = 3

extension_ranges=[],
oneofs=[
],
serialized_start=349,
serialized_end=374,
serialized_start=419,
serialized_end=444,
)
_OBSERVATIONPROTO = _descriptor.Descriptor(

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='observation_type', full_name='communicator_objects.ObservationProto.observation_type', index=6,
number=7, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],

index=0, containing_type=None, fields=[]),
],
serialized_start=79,
serialized_end=394,
serialized_end=464,
_OBSERVATIONPROTO.fields_by_name['observation_type'].enum_type = _OBSERVATIONTYPEPROTO
_OBSERVATIONPROTO.oneofs_by_name['observation_data'].fields.append(
_OBSERVATIONPROTO.fields_by_name['compressed_data'])
_OBSERVATIONPROTO.fields_by_name['compressed_data'].containing_oneof = _OBSERVATIONPROTO.oneofs_by_name['observation_data']

DESCRIPTOR.message_types_by_name['ObservationProto'] = _OBSERVATIONPROTO
DESCRIPTOR.enum_types_by_name['CompressionTypeProto'] = _COMPRESSIONTYPEPROTO
DESCRIPTOR.enum_types_by_name['ObservationTypeProto'] = _OBSERVATIONTYPEPROTO
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
ObservationProto = _reflection.GeneratedProtocolMessageType('ObservationProto', (_message.Message,), dict(

27
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi


NONE = typing___cast('CompressionTypeProto', 0)
PNG = typing___cast('CompressionTypeProto', 1)
class ObservationTypeProto(builtin___int):
DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ...
@classmethod
def Name(cls, number: builtin___int) -> builtin___str: ...
@classmethod
def Value(cls, name: builtin___str) -> 'ObservationTypeProto': ...
@classmethod
def keys(cls) -> typing___List[builtin___str]: ...
@classmethod
def values(cls) -> typing___List['ObservationTypeProto']: ...
@classmethod
def items(cls) -> typing___List[typing___Tuple[builtin___str, 'ObservationTypeProto']]: ...
DEFAULT = typing___cast('ObservationTypeProto', 0)
GOAL = typing___cast('ObservationTypeProto', 1)
REWARD = typing___cast('ObservationTypeProto', 2)
MESSAGE = typing___cast('ObservationTypeProto', 3)
DEFAULT = typing___cast('ObservationTypeProto', 0)
GOAL = typing___cast('ObservationTypeProto', 1)
REWARD = typing___cast('ObservationTypeProto', 2)
MESSAGE = typing___cast('ObservationTypeProto', 3)
class ObservationProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
class FloatData(google___protobuf___message___Message):

compressed_data = ... # type: builtin___bytes
compressed_channel_mapping = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
dimension_properties = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
observation_type = ... # type: ObservationTypeProto
@property
def float_data(self) -> ObservationProto.FloatData: ...

float_data : typing___Optional[ObservationProto.FloatData] = None,
compressed_channel_mapping : typing___Optional[typing___Iterable[builtin___int]] = None,
dimension_properties : typing___Optional[typing___Iterable[builtin___int]] = None,
observation_type : typing___Optional[ObservationTypeProto] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> ObservationProto: ...

def HasField(self, field_name: typing_extensions___Literal[u"compressed_data",u"float_data",u"observation_data"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"dimension_properties",u"float_data",u"observation_data",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",u"compressed_data",u"compression_type",u"dimension_properties",u"float_data",u"observation_data",u"observation_type",u"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"dimension_properties",b"dimension_properties",u"float_data",b"float_data",u"observation_data",b"observation_data",u"shape",b"shape"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"compressed_channel_mapping",b"compressed_channel_mapping",u"compressed_data",b"compressed_data",u"compression_type",b"compression_type",u"dimension_properties",b"dimension_properties",u"float_data",b"float_data",u"observation_data",b"observation_data",u"observation_type",b"observation_type",u"shape",b"shape"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions___Literal[u"observation_data",b"observation_data"]) -> typing_extensions___Literal["compressed_data","float_data"]: ...

8
ml-agents-envs/mlagents_envs/env_utils.py


from mlagents_envs.exception import UnityEnvironmentException
logger = get_logger(__name__)
def get_platform():
"""
Returns the platform of the operating system: linux, darwin or win32

.replace(".x86", "")
)
true_filename = os.path.basename(os.path.normpath(env_path))
get_logger(__name__).debug(f"The true file name is {true_filename}")
logger.debug(f"The true file name is {true_filename}")
if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
return None

f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
)
else:
get_logger(__name__).debug(f"This is the launch string {launch_string}")
logger.debug(f"The launch string is {launch_string}")
logger.debug(f"Running with args {args}")
# Launch Unity environment
subprocess_args = [launch_string] + args
try:

55
ml-agents-envs/mlagents_envs/environment.py


# * 1.1.0 - support concatenated PNGs for compressed observations.
# * 1.2.0 - support compression mapping for stacked compressed observations.
# * 1.3.0 - support action spaces with both continuous and discrete actions.
API_VERSION = "1.3.0"
# * 1.4.0 - support training analytics sent from python trainer to the editor.
API_VERSION = "1.4.0"
# Default port that the editor listens on. If an environment executable
# isn't specified, this port will be used.

capabilities.concatenatedPngObservations = True
capabilities.compressedChannelMapping = True
capabilities.hybridActions = True
capabilities.trainingAnalytics = True
return capabilities
@staticmethod

# If true, this means the environment was successfully loaded
self._loaded = False
# The process that is started. If None, no process was started
self._proc1 = None
self._process: Optional[subprocess.Popen] = None
self.academy_capabilities: UnityRLCapabilitiesProto = None # type: ignore
# If the environment name is None, a new environment will not be launched
# and the communicator will directly try to connect to an existing unity environment.

)
if file_name is not None:
try:
self._proc1 = env_utils.launch_executable(
self._process = env_utils.launch_executable(
file_name, self._executable_args()
)
except UnityEnvironmentException:

self._env_actions: Dict[str, ActionTuple] = {}
self._is_first_message = True
self._update_behavior_specs(aca_output)
self.academy_capabilities = aca_params.capabilities
@staticmethod
def _get_communicator(worker_id, base_port, timeout_wait):

if self._no_graphics:
args += ["-nographics", "-batchmode"]
args += [UnityEnvironment._PORT_COMMAND_LINE_ARG, str(self._port)]
if self._log_folder:
# If the logfile arg isn't already set in the env args,
# try to set it to an output directory
logfile_set = "-logfile" in (arg.lower() for arg in self._additional_args)
if self._log_folder and not logfile_set:
log_file_path = os.path.join(
self._log_folder, f"Player-{self._worker_id}.log"
)

def reset(self) -> None:
if self._loaded:
outputs = self._communicator.exchange(self._generate_reset_input())
outputs = self._communicator.exchange(
self._generate_reset_input(), self._poll_process
)
if outputs is None:
raise UnityCommunicatorStoppedException("Communicator has exited.")
self._update_behavior_specs(outputs)

].action_spec.empty_action(n_agents)
step_input = self._generate_step_input(self._env_actions)
with hierarchical_timer("communicator.exchange"):
outputs = self._communicator.exchange(step_input)
outputs = self._communicator.exchange(step_input, self._poll_process)
if outputs is None:
raise UnityCommunicatorStoppedException("Communicator has exited.")
self._update_behavior_specs(outputs)

self._assert_behavior_exists(behavior_name)
return self._env_state[behavior_name]
def _poll_process(self) -> None:
"""
Check the status of the subprocess. If it has exited, raise a UnityEnvironmentException.
:return: None
"""
if not self._process:
return
poll_res = self._process.poll()
if poll_res is not None:
exc_msg = self._returncode_to_env_message(self._process.returncode)
raise UnityEnvironmentException(exc_msg)
def close(self):
"""
Sends a shutdown signal to the unity environment, and closes the socket connection.

timeout = self._timeout_wait
self._loaded = False
self._communicator.close()
if self._proc1 is not None:
if self._process is not None:
self._proc1.wait(timeout=timeout)
signal_name = self._returncode_to_signal_name(self._proc1.returncode)
signal_name = f" ({signal_name})" if signal_name else ""
return_info = f"Environment shut down with return code {self._proc1.returncode}{signal_name}."
logger.info(return_info)
self._process.wait(timeout=timeout)
logger.info(self._returncode_to_env_message(self._process.returncode))
self._proc1.kill()
self._process.kill()
self._proc1 = None
self._process = None
@timed
def _generate_step_input(

) -> UnityOutputProto:
inputs = UnityInputProto()
inputs.rl_initialization_input.CopyFrom(init_parameters)
return self._communicator.initialize(inputs)
return self._communicator.initialize(inputs, self._poll_process)
@staticmethod
def _wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto:

except Exception:
# Should generally be a ValueError, but catch everything just in case.
return None
@staticmethod
def _returncode_to_env_message(returncode: int) -> str:
signal_name = UnityEnvironment._returncode_to_signal_name(returncode)
signal_name = f" ({signal_name})" if signal_name else ""
return f"Environment shut down with return code {returncode}{signal_name}."

12
ml-agents-envs/mlagents_envs/mock_communicator.py


from .communicator import Communicator
from typing import Optional
from .communicator import Communicator, PollCallback
from .environment import UnityEnvironment
from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto
from mlagents_envs.communicator_objects.brain_parameters_pb2 import (

self.brain_name = brain_name
self.vec_obs_size = vec_obs_size
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
def initialize(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> UnityOutputProto:
if self.is_discrete:
action_spec = ActionSpecProto(
num_discrete_actions=2, discrete_branch_sizes=[3, 2]

)
return dict_agent_info
def exchange(self, inputs: UnityInputProto) -> UnityOutputProto:
def exchange(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> UnityOutputProto:
result = UnityRLOutputProto(agentInfos=self._get_agent_infos())
return UnityOutputProto(rl_output=result)

47
ml-agents-envs/mlagents_envs/rpc_communicator.py


import grpc
from typing import Optional
from multiprocessing import Pipe
import time
from .communicator import Communicator
from .communicator import Communicator, PollCallback
from mlagents_envs.communicator_objects.unity_to_external_pb2_grpc import (
UnityToExternalProtoServicer,
add_UnityToExternalProtoServicer_to_server,

finally:
s.close()
def poll_for_timeout(self):
def poll_for_timeout(self, poll_callback: Optional[PollCallback] = None) -> None:
Additionally, a callback can be passed to periodically check the state of the environment.
This is used to detect the case when the environment dies without cleaning up the connection,
so that we can stop sooner and raise a more appropriate error.
if not self.unity_to_external.parent_conn.poll(self.timeout_wait):
raise UnityTimeOutException(
"The Unity environment took too long to respond. Make sure that :\n"
"\t The environment does not need user interaction to launch\n"
'\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n'
"\t The environment and the Python interface have compatible versions."
)
deadline = time.monotonic() + self.timeout_wait
callback_timeout_wait = self.timeout_wait // 10
while time.monotonic() < deadline:
if self.unity_to_external.parent_conn.poll(callback_timeout_wait):
# Got an acknowledgment from the connection
return
if poll_callback:
# Fire the callback - if it detects something wrong, it should raise an exception.
poll_callback()
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
self.poll_for_timeout()
# Got this far without reading any data from the connection, so it must be dead.
raise UnityTimeOutException(
"The Unity environment took too long to respond. Make sure that :\n"
"\t The environment does not need user interaction to launch\n"
'\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n'
"\t The environment and the Python interface have compatible versions."
)
def initialize(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> UnityOutputProto:
self.poll_for_timeout(poll_callback)
aca_param = self.unity_to_external.parent_conn.recv().unity_output
message = UnityMessageProto()
message.header.status = 200

return aca_param
def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]:
def exchange(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> Optional[UnityOutputProto]:
self.poll_for_timeout()
self.poll_for_timeout(poll_callback)
output = self.unity_to_external.parent_conn.recv()
if output.header.status != 200:
return None

44
ml-agents-envs/mlagents_envs/rpc_utils.py


from mlagents_envs.base_env import (
ActionSpec,
SensorSpec,
ObservationSpec,
ObservationType,
)
from mlagents_envs.exception import UnityObservationException
from mlagents_envs.timers import hierarchical_timer, timed

:param agent_info: protobuf object.
:return: BehaviorSpec object.
"""
observation_shape = [tuple(obs.shape) for obs in agent_info.observations]
dim_props = [
tuple(DimensionProperty(dim) for dim in obs.dimension_properties)
for obs in agent_info.observations
]
sensor_specs = [
SensorSpec(obs_shape, dim_p)
for obs_shape, dim_p in zip(observation_shape, dim_props)
]
observation_specs = []
for obs in agent_info.observations:
observation_specs.append(
ObservationSpec(
tuple(obs.shape),
tuple(DimensionProperty(dim) for dim in obs.dimension_properties),
ObservationType(obs.observation_type),
)
)
# proto from communicator < v1.3 does not set action spec, use deprecated fields instead
if (
brain_param_proto.action_spec.num_continuous_actions == 0

action_spec_proto.num_continuous_actions,
tuple(branch for branch in action_spec_proto.discrete_branch_sizes),
)
return BehaviorSpec(sensor_specs, action_spec)
return BehaviorSpec(observation_specs, action_spec)
class OffsetBytesIO:

]
decision_obs_list: List[np.ndarray] = []
terminal_obs_list: List[np.ndarray] = []
for obs_index, sensor_specs in enumerate(behavior_spec.sensor_specs):
is_visual = len(sensor_specs.shape) == 3
for obs_index, observation_specs in enumerate(behavior_spec.observation_specs):
is_visual = len(observation_specs.shape) == 3
obs_shape = cast(Tuple[int, int, int], sensor_specs.shape)
obs_shape = cast(Tuple[int, int, int], observation_specs.shape)
decision_obs_list.append(
_process_visual_observation(
obs_index, obs_shape, decision_agent_info_list

else:
decision_obs_list.append(
_process_vector_observation(
obs_index, sensor_specs.shape, decision_agent_info_list
obs_index, observation_specs.shape, decision_agent_info_list
obs_index, sensor_specs.shape, terminal_agent_info_list
obs_index, observation_specs.shape, terminal_agent_info_list
)
)
decision_rewards = np.array(

if len(terminal_team_manager) == 0:
terminal_team_manager = None
decision_team_managers = [
agent_info.team_manager_id for agent_info in decision_agent_info_list
]
terminal_team_managers = [
agent_info.team_manager_id for agent_info in terminal_agent_info_list
]
_raise_on_nan_and_inf(decision_rewards, "rewards")
_raise_on_nan_and_inf(terminal_rewards, "rewards")

decision_rewards,
decision_agent_id,
action_mask,
decision_team_manager,
decision_team_managers,
),
TerminalSteps(
terminal_obs_list,

terminal_team_manager,
terminal_team_managers,
),
)

2
ml-agents-envs/mlagents_envs/side_channel/engine_configuration_channel.py


"""
raise UnityCommunicationException(
"The EngineConfigurationChannel received a message from Unity, "
+ "this should not have happend."
+ "this should not have happened."
)
def set_configuration_parameters(

2
ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py


def on_message_received(self, msg: IncomingMessage) -> None:
raise UnityCommunicationException(
"The EnvironmentParametersChannel received a message from Unity, "
+ "this should not have happend."
+ "this should not have happened."
)
def set_float_parameter(self, key: str, value: float) -> None:

14
ml-agents-envs/mlagents_envs/tests/test_envs.py


env.close()
assert isinstance(decision_steps, DecisionSteps)
assert isinstance(terminal_steps, TerminalSteps)
assert len(spec.sensor_specs) == len(decision_steps.obs)
assert len(spec.sensor_specs) == len(terminal_steps.obs)
assert len(spec.observation_specs) == len(decision_steps.obs)
assert len(spec.observation_specs) == len(terminal_steps.obs)
for sen_spec, obs in zip(spec.sensor_specs, decision_steps.obs):
for sen_spec, obs in zip(spec.observation_specs, decision_steps.obs):
for sen_spec, obs in zip(spec.sensor_specs, terminal_steps.obs):
for sen_spec, obs in zip(spec.observation_specs, terminal_steps.obs):
assert (n_agents,) + sen_spec.shape == obs.shape

env.close()
assert isinstance(decision_steps, DecisionSteps)
assert isinstance(terminal_steps, TerminalSteps)
assert len(spec.sensor_specs) == len(decision_steps.obs)
assert len(spec.sensor_specs) == len(terminal_steps.obs)
for spec, obs in zip(spec.sensor_specs, decision_steps.obs):
assert len(spec.observation_specs) == len(decision_steps.obs)
assert len(spec.observation_specs) == len(terminal_steps.obs)
for spec, obs in zip(spec.observation_specs, decision_steps.obs):
assert (n_agents,) + spec.shape == obs.shape
assert 0 in decision_steps
assert 2 in terminal_steps

54
ml-agents-envs/mlagents_envs/tests/test_rpc_communicator.py


import pytest
from unittest import mock
import grpc
import mlagents_envs.rpc_communicator
from mlagents_envs.exception import UnityWorkerInUseException
from mlagents_envs.exception import (
UnityWorkerInUseException,
UnityTimeOutException,
UnityEnvironmentException,
)
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
def test_rpc_communicator_checks_port_on_create():

second_comm = RpcCommunicator(worker_id=1)
first_comm.close()
second_comm.close()
@mock.patch.object(grpc, "server")
@mock.patch.object(
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_OK(mock_impl, mock_grpc_server):
comm = RpcCommunicator(timeout_wait=0.25)
comm.unity_to_external.parent_conn.poll.return_value = True
input = UnityInputProto()
comm.initialize(input)
comm.unity_to_external.parent_conn.poll.assert_called()
@mock.patch.object(grpc, "server")
@mock.patch.object(
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_timeout(mock_impl, mock_grpc_server):
comm = RpcCommunicator(timeout_wait=0.25)
comm.unity_to_external.parent_conn.poll.return_value = None
input = UnityInputProto()
# Expect a timeout
with pytest.raises(UnityTimeOutException):
comm.initialize(input)
comm.unity_to_external.parent_conn.poll.assert_called()
@mock.patch.object(grpc, "server")
@mock.patch.object(
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_callback(mock_impl, mock_grpc_server):
def callback():
raise UnityEnvironmentException
comm = RpcCommunicator(timeout_wait=0.25)
comm.unity_to_external.parent_conn.poll.return_value = None
input = UnityInputProto()
# Expect a timeout
with pytest.raises(UnityEnvironmentException):
comm.initialize(input, poll_callback=callback)
comm.unity_to_external.parent_conn.poll.assert_called()

19
ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py


steps_from_proto,
)
from PIL import Image
from mlagents.trainers.tests.dummy_config import create_sensor_specs_with_shapes
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes
def generate_list_agent_proto(

n_agents = 10
shapes = [(3,), (4,)]
spec = BehaviorSpec(
create_sensor_specs_with_shapes(shapes), ActionSpec.create_continuous(3)
create_observation_specs_with_shapes(shapes), ActionSpec.create_continuous(3)
)
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, spec)

n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(
create_sensor_specs_with_shapes(shapes), ActionSpec.create_discrete((7, 3))
create_observation_specs_with_shapes(shapes), ActionSpec.create_discrete((7, 3))
)
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)

n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(
create_sensor_specs_with_shapes(shapes), ActionSpec.create_discrete((10,))
create_observation_specs_with_shapes(shapes), ActionSpec.create_discrete((10,))
)
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)

n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(
create_sensor_specs_with_shapes(shapes), ActionSpec.create_discrete((2, 2, 6))
create_observation_specs_with_shapes(shapes),
ActionSpec.create_discrete((2, 2, 6)),
)
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)

n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(
create_sensor_specs_with_shapes(shapes), ActionSpec.create_continuous(10)
create_observation_specs_with_shapes(shapes), ActionSpec.create_continuous(10)
)
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)

behavior_spec = behavior_spec_from_proto(bp, agent_proto)
assert behavior_spec.action_spec.is_discrete()
assert not behavior_spec.action_spec.is_continuous()
assert [spec.shape for spec in behavior_spec.sensor_specs] == [(3,), (4,)]
assert [spec.shape for spec in behavior_spec.observation_specs] == [(3,), (4,)]
assert behavior_spec.action_spec.discrete_branches == (5, 4)
assert behavior_spec.action_spec.discrete_size == 2
bp = BrainParametersProto()

n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(
create_sensor_specs_with_shapes(shapes), ActionSpec.create_continuous(3)
create_observation_specs_with_shapes(shapes), ActionSpec.create_continuous(3)
)
ap_list = generate_list_agent_proto(n_agents, shapes, infinite_rewards=True)
with pytest.raises(RuntimeError):

n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(
create_sensor_specs_with_shapes(shapes), ActionSpec.create_continuous(3)
create_observation_specs_with_shapes(shapes), ActionSpec.create_continuous(3)
)
ap_list = generate_list_agent_proto(n_agents, shapes, nan_observations=True)
with pytest.raises(RuntimeError):

6
ml-agents-envs/mlagents_envs/tests/test_steps.py


ActionSpec,
BehaviorSpec,
)
from mlagents.trainers.tests.dummy_config import create_sensor_specs_with_shapes
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes
def test_decision_steps():

def test_empty_decision_steps():
specs = BehaviorSpec(
sensor_specs=create_sensor_specs_with_shapes([(3, 2), (5,)]),
observation_specs=create_observation_specs_with_shapes([(3, 2), (5,)]),
action_spec=ActionSpec.create_continuous(3),
)
ds = DecisionSteps.empty(specs)

def test_empty_terminal_steps():
specs = BehaviorSpec(
sensor_specs=create_sensor_specs_with_shapes([(3, 2), (5,)]),
observation_specs=create_observation_specs_with_shapes([(3, 2), (5,)]),
action_spec=ActionSpec.create_continuous(3),
)
ts = TerminalSteps.empty(specs)

1
ml-agents/mlagents/torch_utils/__init__.py


from mlagents.torch_utils.torch import torch as torch # noqa
from mlagents.torch_utils.torch import nn # noqa
from mlagents.torch_utils.torch import set_torch_config # noqa
from mlagents.torch_utils.torch import default_device # noqa

37
ml-agents/mlagents/torch_utils/torch.py


from distutils.version import LooseVersion
import pkg_resources
from mlagents.torch_utils import cpu_utils
from mlagents.trainers.settings import TorchSettings
from mlagents_envs.logging_util import get_logger
logger = get_logger(__name__)
def assert_torch_installed():

torch.set_num_threads(cpu_utils.get_num_threads_to_use())
os.environ["KMP_BLOCKTIME"] = "0"
if torch.cuda.is_available():
torch.set_default_tensor_type(torch.cuda.FloatTensor)
device = torch.device("cuda")
else:
torch.set_default_tensor_type(torch.FloatTensor)
device = torch.device("cpu")
_device = torch.device("cpu")
def set_torch_config(torch_settings: TorchSettings) -> None:
global _device
if torch_settings.device is None:
device_str = "cuda" if torch.cuda.is_available() else "cpu"
else:
device_str = torch_settings.device
_device = torch.device(device_str)
if _device.type == "cuda":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
torch.set_default_tensor_type(torch.FloatTensor)
logger.info(f"default Torch device: {_device}")
# Initialize to default settings
set_torch_config(TorchSettings(device=None))
return device
return _device
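A minimal usage sketch, mirroring the `run_training.setup` hunk in `learn.py` below: pin the device once at startup, then read it back through `default_device`:

```python
from mlagents.trainers.settings import TorchSettings
from mlagents.torch_utils import default_device, set_torch_config

# device=None falls back to CUDA when available, otherwise CPU.
set_torch_config(TorchSettings(device="cpu"))
print(default_device())  # -> device(type='cpu')
```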

15
ml-agents/mlagents/trainers/cli_utils.py


action=DetectDefault,
)
argparser.add_argument(
"--cpu",
default=False,
action=DetectDefaultStoreTrue,
help="Forces training using CPU only",
)
argparser.add_argument(
"--torch",
default=False,
action=RaiseRemovedWarning,

action=DetectDefaultStoreTrue,
help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
"the graphics driver. Use this only if your agents don't use visual observations.",
)
torch_conf = argparser.add_argument_group(title="Torch Configuration")
torch_conf.add_argument(
"--torch-device",
default=None,
dest="device",
action=DetectDefault,
help='Settings for the default torch.device used in training, for example, "cpu", "cuda", or "cuda:0"',
)
return argparser

9
ml-agents/mlagents/trainers/demo_loader.py


)
)
# check observations match
if len(behavior_spec.sensor_specs) != len(expected_behavior_spec.sensor_specs):
if len(behavior_spec.observation_specs) != len(
expected_behavior_spec.observation_specs
):
zip(behavior_spec.sensor_specs, expected_behavior_spec.sensor_specs)
zip(
behavior_spec.observation_specs,
expected_behavior_spec.observation_specs,
)
):
if demo_obs.shape != policy_obs.shape:
raise RuntimeError(

12
ml-agents/mlagents/trainers/env_manager.py


from mlagents.trainers.policy import Policy
from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.settings import TrainerSettings
from mlagents_envs.logging_util import get_logger
AllStepResult = Dict[BehaviorName, Tuple[DecisionSteps, TerminalSteps]]

Sends environment parameter settings to C# via the
EnvironmentParametersSideChannel.
:param config: Dict of environment parameter keys and values
"""
pass
def on_training_started(
self, behavior_name: str, trainer_settings: TrainerSettings
) -> None:
"""
Handle training starting for a new behavior type. Generally nothing is necessary here.
:param behavior_name:
:param trainer_settings:
:return:
"""
pass

15
ml-agents/mlagents/trainers/learn.py


from mlagents_envs.base_env import BaseEnv
from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.timers import (
hierarchical_timer,
get_timer_tree,

:param run_options: Command line arguments for training.
"""
with hierarchical_timer("run_training.setup"):
torch_utils.set_torch_config(options.torch_settings)
checkpoint_settings = options.checkpoint_settings
env_settings = options.env_settings
engine_settings = options.engine_settings

env_settings.env_args,
os.path.abspath(run_logs_dir), # Unity environment requires absolute path
)
engine_config = EngineConfig(
width=engine_settings.width,
height=engine_settings.height,
quality_level=engine_settings.quality_level,
time_scale=engine_settings.time_scale,
target_frame_rate=engine_settings.target_frame_rate,
capture_frame_rate=engine_settings.capture_frame_rate,
)
env_manager = SubprocessEnvManager(
env_factory, engine_config, env_settings.num_envs
)
env_manager = SubprocessEnvManager(env_factory, options, env_settings.num_envs)
env_parameter_manager = EnvironmentParameterManager(
options.environment_parameters, run_seed, restore=checkpoint_settings.resume
)

4
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


done: bool,
all_dones: bool,
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
n_obs = len(self.policy.behavior_spec.sensor_specs)
n_obs = len(self.policy.behavior_spec.observation_specs)
current_obs = ObsUtil.from_buffer(batch, n_obs)
team_obs = TeamObsUtil.from_buffer(batch, n_obs)
# next_obs = ObsUtil.from_buffer_next(batch, n_obs)

Some files were not shown because too many files changed in this diff
