浏览代码

Merge remote-tracking branch 'origin/master' into develop-add-fire

/develop/add-fire
Arthur Juliani 4 年前
当前提交
28e095e0
共有 64 个文件被更改,包括 1751 次插入427 次删除
  1. 1
      .github/ISSUE_TEMPLATE/bug_report.md
  2. 23
      .pre-commit-config.yaml
  3. 35
      README.md
  4. 4
      com.unity.ml-agents/CHANGELOG.md
  5. 23
      com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
  6. 20
      com.unity.ml-agents/Editor/BrainParametersDrawer.cs
  7. 18
      com.unity.ml-agents/Runtime/Agent.cs
  8. 43
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  9. 43
      com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
  10. 6
      com.unity.ml-agents/Runtime/Policies/BrainParameters.cs
  11. 24
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
  12. 65
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  13. 2
      com.unity.ml-agents/Tests/Editor/PublicAPI/Unity.ML-Agents.Editor.Tests.PublicAPI.asmdef
  14. 10
      com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
  15. 5
      docs/Learning-Environment-Create-New.md
  16. 2
      docs/Migrating.md
  17. 17
      docs/Python-API.md
  18. 2
      docs/Training-ML-Agents.md
  19. 2
      docs/Using-Tensorboard.md
  20. 21
      gym-unity/gym_unity/envs/__init__.py
  21. 5
      gym-unity/gym_unity/tests/test_gym.py
  22. 50
      ml-agents-envs/mlagents_envs/base_env.py
  23. 293
      ml-agents-envs/mlagents_envs/environment.py
  24. 54
      ml-agents-envs/mlagents_envs/tests/test_envs.py
  25. 47
      ml-agents-envs/mlagents_envs/tests/test_side_channel.py
  26. 14
      ml-agents/mlagents/trainers/learn.py
  27. 1
      ml-agents/mlagents/trainers/policy/tf_policy.py
  28. 4
      ml-agents/mlagents/trainers/ppo/trainer.py
  29. 8
      ml-agents/mlagents/trainers/simple_env_manager.py
  30. 10
      ml-agents/mlagents/trainers/subprocess_env_manager.py
  31. 12
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  32. 5
      ml-agents/mlagents/trainers/tests/test_learn.py
  33. 4
      ml-agents/mlagents/trainers/tests/test_nn_policy.py
  34. 6
      ml-agents/tests/yamato/check_coverage_percent.py
  35. 4
      ml-agents/tests/yamato/scripts/run_llapi.py
  36. 2
      ml-agents/tests/yamato/yamato_utils.py
  37. 6
      utils/validate_versions.py
  38. 8
      com.unity.ml-agents/Runtime/Sensors/Reflection.meta
  39. 11
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta
  40. 292
      com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs
  41. 11
      com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta
  42. 108
      ml-agents-envs/mlagents_envs/env_utils.py
  43. 81
      ml-agents-envs/mlagents_envs/side_channel/side_channel_manager.py
  44. 64
      ml-agents-envs/mlagents_envs/tests/test_env_utils.py
  45. 102
      ml-agents-envs/mlagents_envs/tests/test_steps.py
  46. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta
  47. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta
  48. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta
  49. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta
  50. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta
  51. 97
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  52. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta
  53. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta
  54. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta
  55. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta
  56. 19
      com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs
  57. 19
      com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs
  58. 19
      com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs
  59. 272
      com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs
  60. 22
      com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs
  61. 20
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs
  62. 21
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs
  63. 22
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs

1
.github/ISSUE_TEMPLATE/bug_report.md


If applicable, add screenshots to help explain your problem.
**Environment (please complete the following information):**
- Unity Version: [e.g. Unity 2020.1f1]
- OS + version: [e.g. Windows 10]
- _ML-Agents version_: (e.g. ML-Agents v0.8, or latest `develop` branch from source)
- _TensorFlow version_: (you can run `pip3 show tensorflow` to get this)

23
.pre-commit-config.yaml


files: "gym-unity/.*"
args: [--ignore-missing-imports, --disallow-incomplete-defs]
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.1
hooks:
- id: flake8
exclude: >
(?x)^(
.*_pb2.py|
.*_pb2_grpc.py
)$
# flake8-tidy-imports is used for banned-modules, not actually tidying
additional_dependencies: [flake8-comprehensions==3.2.2, flake8-tidy-imports==4.1.0, flake8-bugbear==20.1.4]
rev: v2.4.0
rev: v2.5.0
hooks:
- id: mixed-line-ending
exclude: >

.*.meta
)$
args: [--fix=lf]
- id: flake8
exclude: >
(?x)^(
.*_pb2.py|
.*_pb2_grpc.py
)$
# flake8-tidy-imports is used for banned-modules, not actually tidying
additional_dependencies: [flake8-comprehensions==3.1.4, flake8-tidy-imports==4.0.0, flake8-bugbear==20.1.2]
- id: trailing-whitespace
name: trailing-whitespace-markdown
types: [markdown]

35
README.md


## Releases & Documentation
**Our latest, stable release is `Release 1`. Click [here](docs/Readme.md) to
get started with the latest release of ML-Agents.**
**Our latest, stable release is `Release 1`. Click
[here](https://github.com/Unity-Technologies/ml-agents/tree/release_1/docs/Readme.md)
to get started with the latest release of ML-Agents.**
The table below lists all our releases, including our `master` branch which is under active
development and may be unstable. A few helpful guidelines:
* The docs links in the table below include installation and usage instructions specific to each
release. Remember to always use the documentation that corresponds to the release version you're
using.
* See the [GitHub releases](https://github.com/Unity-Technologies/ml-agents/releases) for more
details of the changes between versions.
* If you have used an earlier version of the ML-Agents Toolkit, we strongly recommend our
[guide on migrating from earlier versions](docs/Migrating.md).
The table below lists all our releases, including our `master` branch which is
under active development and may be unstable. A few helpful guidelines:
- The [Versioning page](docs/Versioning.md) overviews how we manage our GitHub
releases and the versioning process for each of the ML-Agents components.
- The [Releases page](https://github.com/Unity-Technologies/ml-agents/releases)
contains details of the changes between releases.
- The [Migration page](docs/Migrating.md) contains details on how to upgrade
from earlier releases of the ML-Agents Toolkit.
- The **Documentation** links in the table below include installation and usage
instructions specific to each release. Remember to always use the
documentation that corresponds to the release version you're using.
| **Version** | **Release Date** | **Source** | **Documentation** | **Download** |
|:-------:|:------:|:-------------:|:-------:|:------------:|

If you use Unity or the ML-Agents Toolkit to conduct research, we ask that you
cite the following paper as a reference:
Juliani, A., Berges, V., Vckay, E., Gao, Y., Henry, H., Mattar, M., Lange, D.
(2018). Unity: A General Platform for Intelligent Agents. _arXiv preprint
arXiv:1809.02627._ https://github.com/Unity-Technologies/ml-agents.
Juliani, A., Berges, V., Teng, E., Cohen, A., Harper, J., Elion, C., Goy, C.,
Gao, Y., Henry, H., Mattar, M., Lange, D. (2020). Unity: A General Platform for
Intelligent Agents. _arXiv preprint
[arXiv:1809.02627](https://arxiv.org/abs/1809.02627)._
https://github.com/Unity-Technologies/ml-agents.
- (May 12, 2020)
[Announcing ML-Agents Unity Package v1.0!](https://blogs.unity3d.com/2020/05/12/announcing-ml-agents-unity-package-v1-0/)
- (February 28, 2020)
[Training intelligent adversaries using self-play with ML-Agents](https://blogs.unity3d.com/2020/02/28/training-intelligent-adversaries-using-self-play-with-ml-agents/)
- (November 11, 2019)

4
com.unity.ml-agents/CHANGELOG.md


#### ml-agents / ml-agents-envs / gym-unity (Python)
- `max_step` in the `TerminalStep` and `TerminalSteps` objects was renamed `interrupted`.
- `beta` and `epsilon` in `PPO` are no longer decayed by default but follow the same schedule as learning rate. (#3940)
- `get_behavior_names()` and `get_behavior_spec()` on UnityEnvironment were replaced by the `behavior_specs` property. (#3946)
- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate
observations via reflection.
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Curriculum and Parameter Randomization configurations have been merged
into the main training configuration file. Note that this means training

- Unity Player logs are now written out to the results directory. (#3877)
- Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
### Bug Fixes
- An issue was fixed where using `--initialize-from` would resume from the past step count. (#3962)
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)

23
com.unity.ml-agents/Editor/BehaviorParametersEditor.cs


using System.Collections.Generic;
using Unity.MLAgents.Sensors.Reflection;
using UnityEngine;
namespace Unity.MLAgents.Editor

EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
EditorGUILayout.PropertyField(so.FindProperty("m_UseChildSensors"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservableAttributeHandling"), true);
}
EditorGUI.EndDisabledGroup();

Model barracudaModel = null;
var model = (NNModel)serializedObject.FindProperty("m_Model").objectReferenceValue;
var behaviorParameters = (BehaviorParameters)target;
// Grab the sensor components, since we need them to determine the observation sizes.
SensorComponent[] sensorComponents;
if (behaviorParameters.UseChildSensors)
{

{
sensorComponents = behaviorParameters.GetComponents<SensorComponent>();
}
// Get the total size of the sensors generated by ObservableAttributes.
// If there are any errors (e.g. unsupported type, write-only properties), display them too.
int observableAttributeSensorTotalSize = 0;
var agent = behaviorParameters.GetComponent<Agent>();
if (agent != null && behaviorParameters.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
{
List<string> observableErrors = new List<string>();
observableAttributeSensorTotalSize = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors);
foreach (var check in observableErrors)
{
EditorGUILayout.HelpBox(check, MessageType.Warning);
}
}
var brainParameters = behaviorParameters.BrainParameters;
if (model != null)
{

{
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
barracudaModel, brainParameters, sensorComponents, behaviorParameters.BehaviorType
barracudaModel, brainParameters, sensorComponents,
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
);
foreach (var check in failedChecks)
{

20
com.unity.ml-agents/Editor/BrainParametersDrawer.cs


static void DrawContinuousVectorAction(Rect position, SerializedProperty property)
{
var vecActionSize = property.FindPropertyRelative(k_ActionSizePropName);
vecActionSize.arraySize = 1;
// This check is here due to:
// https://fogbugz.unity3d.com/f/cases/1246524/
// If this case has been resolved, please remove this if condition.
if (vecActionSize.arraySize != 1)
{
vecActionSize.arraySize = 1;
}
var continuousActionSize =
vecActionSize.GetArrayElementAtIndex(0);
EditorGUI.PropertyField(

static void DrawDiscreteVectorAction(Rect position, SerializedProperty property)
{
var vecActionSize = property.FindPropertyRelative(k_ActionSizePropName);
vecActionSize.arraySize = EditorGUI.IntField(
var newSize = EditorGUI.IntField(
// This check is here due to:
// https://fogbugz.unity3d.com/f/cases/1246524/
// If this case has been resolved, please remove this if condition.
if (newSize != vecActionSize.arraySize)
{
vecActionSize.arraySize = newSize;
}
position.y += k_LineHeight;
position.x += 20;
position.width -= 20;

18
com.unity.ml-agents/Runtime/Agent.cs


using UnityEngine;
using Unity.Barracuda;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
using UnityEngine.Serialization;

m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic);
ResetData();
Initialize();
InitializeSensors();
using (TimerStack.Instance.Scoped("InitializeSensors"))
{
InitializeSensors();
}
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

/// </summary>
internal void InitializeSensors()
{
if (m_PolicyFactory.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
{
var excludeInherited =
m_PolicyFactory.ObservableAttributeHandling == ObservableAttributeOptions.ExcludeInherited;
using (TimerStack.Instance.Scoped("CreateObservableSensors"))
{
var observableSensors = ObservableAttribute.CreateObservableSensors(this, excludeInherited);
sensors.AddRange(observableSensors);
}
}
// Get all attached sensor components
SensorComponent[] attachedSensorComponents;
if (m_PolicyFactory.UseChildSensors)

43
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensorComponents">Attached sensor components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
SensorComponent[] sensorComponents, BehaviorType behaviorType = BehaviorType.Default)
SensorComponent[] sensorComponents, int observableAttributeTotalSize = 0,
BehaviorType behaviorType = BehaviorType.Default)
{
List<string> failedModelChecks = new List<string>();
if (model == null)

CheckOutputTensorPresence(model, memorySize))
;
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensorComponents)
CheckInputTensorShape(model, brainParameters, sensorComponents, observableAttributeTotalSize)
);
failedModelChecks.AddRange(
CheckOutputTensorShape(model, brainParameters, isContinuous, actionSize)

/// Whether the model is expecting continuous or discrete control.
/// </param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="observableAttributeTotalSize">Total size of ObservableAttributes</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed input presence checks.
/// </returns>

/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensorComponents">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents)
Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents,
int observableAttributeTotalSize)
new Dictionary<string, Func<BrainParameters, TensorProxy, SensorComponent[], string>>()
new Dictionary<string, Func<BrainParameters, TensorProxy, SensorComponent[], int, string>>()
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)},
tensorTester[mem.input] = ((bp, tensor, scs) => null);
tensorTester[mem.input] = ((bp, tensor, scs, i) => null);
}
var visObsIndex = 0;

continue;
}
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs) => CheckVisualObsShape(tensor, sensorComponent);
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
visObsIndex++;
}

else
{
var tester = tensorTester[tensor.name];
var error = tester.Invoke(brainParameters, tensor, sensorComponents);
var error = tester.Invoke(brainParameters, tensor, sensorComponents, observableAttributeTotalSize);
if (error != null)
{
failedModelChecks.Add(error);

/// </param>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents)
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents,
int observableAttributeTotalSize)
{
var vecObsSizeBp = brainParameters.VectorObservationSize;
var numStackedVector = brainParameters.NumStackedVectorObservations;

totalVectorSensorSize += sensorComp.GetObservationShape()[0];
}
}
totalVectorSensorSize += observableAttributeTotalSize;
if (vecObsSizeBp * numStackedVector + totalVectorSensorSize != totalVecObsSizeT)
{

sensorSizes += "]";
return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
$"but received {vecObsSizeBp} x {numStackedVector} vector observations and " +
$"but received: \n" +
$"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" +
$"Total [Observable] attributes: {observableAttributeTotalSize}\n" +
$"SensorComponent sizes: {sensorSizes}.";
}
return null;

/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="tensorProxy"> The tensor that is expected by the model</param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="sensorComponents">Array of attached sensor components (unused).</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes (unused).</param>
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents)
BrainParameters brainParameters, TensorProxy tensorProxy,
SensorComponent[] sensorComponents, int observableAttributeTotalSize)
{
var numberActionsBp = brainParameters.VectorActionSize.Length;
var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1];

43
com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs


using System;
using UnityEngine;
using UnityEngine.Serialization;
using Unity.MLAgents.Sensors.Reflection;
namespace Unity.MLAgents.Policies
{

/// neural network model.
/// </summary>
InferenceOnly
}
/// <summary>
/// Options for controlling how the Agent class is searched for <see cref="ObservableAttribute"/>s.
/// </summary>
public enum ObservableAttributeOptions
{
/// <summary>
/// All ObservableAttributes on the Agent will be ignored. If there are no
/// ObservableAttributes on the Agent, this will result in the fastest
/// initialization time.
/// </summary>
Ignore,
/// <summary>
/// Only members on the declared class will be examined; members that are
/// inherited are ignored. This is the default behavior, and a reasonable
/// tradeoff between performance and flexibility.
/// </summary>
/// <remarks>This corresponds to setting the
/// [BindingFlags.DeclaredOnly](https://docs.microsoft.com/en-us/dotnet/api/system.reflection.bindingflags?view=netcore-3.1)
/// when examining the fields and properties of the Agent class instance.
/// </remarks>
ExcludeInherited,
/// <summary>
/// All members on the class will be examined. This can lead to slower
/// startup times
/// </summary>
ExamineAll
}
/// <summary>

{
get { return m_UseChildSensors; }
set { m_UseChildSensors = value; }
}
[HideInInspector, SerializeField]
ObservableAttributeOptions m_ObservableAttributeHandling = ObservableAttributeOptions.Ignore;
/// <summary>
/// Determines how the Agent class is searched for <see cref="ObservableAttribute"/>s.
/// </summary>
public ObservableAttributeOptions ObservableAttributeHandling
{
get { return m_ObservableAttributeHandling; }
set { m_ObservableAttributeHandling = value; }
}
/// <summary>

6
com.unity.ml-agents/Runtime/Policies/BrainParameters.cs


public class BrainParameters
{
/// <summary>
/// The size of the observation space.
/// </summary>
/// <remarks>An agent creates the observation vector in its
/// The number of the observations that are added in
/// implementation.</remarks>
/// </summary>
/// <value>
/// The length of the vector containing observation values.
/// </value>

24
com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs


using UnityEngine;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors.Reflection;
namespace Unity.MLAgents.Tests
{

}
}
static List<TestAgent> GetFakeAgents()
static List<TestAgent> GetFakeAgents(ObservableAttributeOptions observableAttributeOptions = ObservableAttributeOptions.Ignore)
bpA.ObservableAttributeHandling = observableAttributeOptions;
var agentA = goA.AddComponent<TestAgent>();
var goB = new GameObject("goB");

bpB.ObservableAttributeHandling = observableAttributeOptions;
var agentB = goB.AddComponent<TestAgent>();
var agents = new List<TestAgent> { agentA, agentB };

{
var inputTensor = new TensorProxy
{
shape = new long[] { 2, 3 }
shape = new long[] { 2, 4 }
var agentInfos = GetFakeAgents();
var agentInfos = GetFakeAgents(ObservableAttributeOptions.ExamineAll);
generator.AddSensorIndex(0);
generator.AddSensorIndex(1);
generator.AddSensorIndex(2);
generator.AddSensorIndex(0); // ObservableAttribute (size 1)
generator.AddSensorIndex(1); // TestSensor (size 0)
generator.AddSensorIndex(2); // TestSensor (size 0)
generator.AddSensorIndex(3); // VectorSensor (size 3)
var agent0 = agentInfos[0];
var agent1 = agentInfos[1];
var inputs = new List<AgentInfoSensorsPair>

};
generator.Generate(inputTensor, batchSize, inputs);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 1);
Assert.AreEqual(inputTensor.data[0, 2], 3);
Assert.AreEqual(inputTensor.data[1, 0], 4);
Assert.AreEqual(inputTensor.data[1, 2], 6);
Assert.AreEqual(inputTensor.data[0, 1], 1);
Assert.AreEqual(inputTensor.data[0, 3], 3);
Assert.AreEqual(inputTensor.data[1, 1], 4);
Assert.AreEqual(inputTensor.data[1, 3], 6);
alloc.Dispose();
}

65
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


using System.Reflection;
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Policies;
using Unity.MLAgents.SideChannels;

public int heuristicCalls;
public TestSensor sensor1;
public TestSensor sensor2;
[Observable("observableFloat")]
public float observableFloat;
public override void Initialize()
{

var agentGo1 = new GameObject("TestAgent");
agentGo1.AddComponent<TestAgent>();
var agent1 = agentGo1.GetComponent<TestAgent>();
var bp1 = agentGo1.GetComponent<BehaviorParameters>();
bp1.ObservableAttributeHandling = ObservableAttributeOptions.ExcludeInherited;
var agentGo2 = new GameObject("TestAgent");
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();

Assert.AreEqual(0, agent2.agentActionCalls);
// Make sure the Sensors were sorted
Assert.AreEqual(agent1.sensors[0].GetName(), "testsensor1");
Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor2");
Assert.AreEqual(agent1.sensors[0].GetName(), "observableFloat");
Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor1");
Assert.AreEqual(agent1.sensors[2].GetName(), "testsensor2");
// agent2 should only have two sensors (no observableFloat)
Assert.AreEqual(agent2.sensors[0].GetName(), "testsensor1");
Assert.AreEqual(agent2.sensors[1].GetName(), "testsensor2");
}
}

public void TestAgentDontCallBaseOnEnable()
{
_InnerAgentTestOnEnableOverride();
}
}
[TestFixture]
public class ObservableAttributeBehaviorTests
{
public class BaseObservableAgent : Agent
{
[Observable]
public float BaseField;
}
public class DerivedObservableAgent : BaseObservableAgent
{
[Observable]
public float DerivedField;
}
[Test]
public void TestObservableAttributeBehaviorIgnore()
{
var variants = new[]
{
// No observables found
(ObservableAttributeOptions.Ignore, 0),
// Only DerivedField found
(ObservableAttributeOptions.ExcludeInherited, 1),
// DerivedField and BaseField found
(ObservableAttributeOptions.ExamineAll, 2)
};
foreach (var(behavior, expectedNumSensors) in variants)
{
var go = new GameObject();
var agent = go.AddComponent<DerivedObservableAgent>();
var bp = go.GetComponent<BehaviorParameters>();
bp.ObservableAttributeHandling = behavior;
agent.LazyInitialize();
int numAttributeSensors = 0;
foreach (var sensor in agent.sensors)
{
if (sensor.GetType() != typeof(VectorSensor))
{
numAttributeSensors++;
}
}
Assert.AreEqual(expectedNumSensors, numAttributeSensors);
}
}
}
}

2
com.unity.ml-agents/Tests/Editor/PublicAPI/Unity.ML-Agents.Editor.Tests.PublicAPI.asmdef


"references": [
"Unity.ML-Agents.Editor",
"Unity.ML-Agents",
"Barracuda",
"Unity.Barracuda",
"Unity.ML-Agents.CommunicatorObjects"
],
"optionalUnityReferences": [

10
com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs


#if UNITY_INCLUDE_TESTS
#if UNITY_INCLUDE_TESTS
using Unity.MLAgents.Sensors.Reflection;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;

[Observable]
public float ObservableFloat;
public override void Heuristic(float[] actionsOut)
{
numHeuristicCalls++;

public override int[] GetObservationShape()
{
int[] shape = (int[]) wrappedComponent.GetObservationShape().Clone();
int[] shape = (int[])wrappedComponent.GetObservationShape().Clone();
for (var i = 0; i < shape.Length; i++)
{
shape[i] *= numStacks;

behaviorParams.BehaviorName = "TestBehavior";
behaviorParams.TeamId = 42;
behaviorParams.UseChildSensors = true;
behaviorParams.ObservableAttributeHandling = ObservableAttributeOptions.ExamineAll;
// Can't actually create an Agent with InferenceOnly and no model, so change back

5
docs/Learning-Environment-Create-New.md


learning_rate: 3.0e-4
learning_rate_schedule: linear
max_steps: 5.0e4
memory_size: 128
normalize: false
num_epoch: 3
num_layers: 2

reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
strength: 1.0
gamma: 0.99
```
Since this example creates a very simple training environment with only a few

2
docs/Migrating.md


- Trainer configuration, curriculum configuration, and parameter randomization
configuration have all been moved to a single YAML file. (#3791)
- `max_step` in the `TerminalStep` and `TerminalSteps` objects was renamed `interrupted`.
- On the UnityEnvironment API, `get_behavior_names()` and `get_behavior_specs()` methods were combined into the property `behavior_specs` that contains a mapping from behavior names to behavior spec.
### Steps to Migrate
- Before upgrading, copy your `Behavior Name` sections from `trainer_config.yaml` into

the contents of the sampler config to `parameter_randomization` in the main trainer configuration.
- If you are using `UnityEnvironment` directly, replace `max_step` with `interrupted`
in the `TerminalStep` and `TerminalSteps` objects.
- Replace usage of `get_behavior_names()` and `get_behavior_specs()` in UnityEnvironment with `behavior_specs`.
## Migrating from 0.15 to Release 1

17
docs/Python-API.md


env = UnityEnvironment(file_name="3DBall", seed=1, side_channels=[])
# Start interacting with the evironment.
env.reset()
behavior_names = env.get_behavior_names()
behavior_names = env.behavior_spec.keys()
...
```
**NOTE:** Please read [Interacting with a Unity Environment](#interacting-with-a-unity-environment)

act.
- **Close : `env.close()`** Sends a shutdown signal to the environment and
terminates the communication.
- **Get Behavior Names : `env.get_behavior_names()`** Returns a list of
`BehaviorName`. Note that the number of groups can change over time in the
simulation if new Agent behaviors are created in the simulation.
- **Get Behavior Spec : `env.get_behavior_spec(behavior_name: str)`** Returns
the `BehaviorSpec` corresponding to the behavior_name given as input. A
`BehaviorSpec` contains information such as the observation shapes, the action
type (multi-discrete or continuous) and the action shape. Note that the
`BehaviorSpec` for a specific group is fixed throughout the simulation.
- **Behavior Specs : `env.behavior_specs`** Returns a Mapping of
`BehaviorName` to `BehaviorSpec` objects (read only).
A `BehaviorSpec` contains information such as the observation shapes, the
action type (multi-discrete or continuous) and the action shape. Note that
the `BehaviorSpec` for a specific group is fixed throughout the simulation.
The number of entries in the Mapping can change over time in the simulation
if new Agent behaviors are created in the simulation.
- **Get Steps : `env.get_steps(behavior_name: str)`** Returns a tuple
`DecisionSteps, TerminalSteps` corresponding to the behavior_name given as
input. The `DecisionSteps` contains information about the state of the agents

2
docs/Training-ML-Agents.md


#### Training with a Curriculum
Once we have specified our metacurriculum and curricula, we can launch
`mlagents-learn` using the `–curriculum` flag to point to the config file for
`mlagents-learn` using the config file for
our curricula and PPO will train using Curriculum Learning. For example, to
train agents in the Wall Jump environment with curriculum learning, we can run:

2
docs/Using-Tensorboard.md


```csharp
var statsRecorder = Academy.Instance.StatsRecorder;
statsSideChannel.Add("MyMetric", 1.0);
statsRecorder.Add("MyMetric", 1.0);
```

21
gym-unity/gym_unity/envs/__init__.py


self._env = unity_env
# Take a single step so that the brain information will be sent over
if not self._env.get_behavior_names():
if not self._env.behavior_specs:
self._n_agents = -1
# Save the step result from the last time all Agents requested decisions.
self._previous_decision_step: DecisionSteps = None

self._allow_multiple_visual_obs = allow_multiple_visual_obs
# Check brain configuration
if len(self._env.get_behavior_names()) != 1:
if len(self._env.behavior_specs) != 1:
self.name = self._env.get_behavior_names()[0]
self.group_spec = self._env.get_behavior_spec(self.name)
self.name = list(self._env.behavior_specs.keys())[0]
self.group_spec = self._env.behavior_specs[self.name]
if use_visual and self._get_n_vis_obs() == 0:
raise UnityGymException(

self._env.step()
decision_step, terminal_step = self._env.get_steps(self.name)
self._check_agents(max(len(decision_step), len(terminal_step)))
if len(terminal_step) != 0:
# The agent is done
self.game_over = True

logger.warning("Could not seed environment %s", self.name)
return
def _check_agents(self, n_agents: int) -> None:
if self._n_agents > 1:
@staticmethod
def _check_agents(n_agents: int) -> None:
if n_agents > 1:
"There can only be one Agent in the environment but {n_agents} were detected."
f"There can only be one Agent in the environment but {n_agents} were detected."
)
@property

@property
def observation_space(self):
return self._observation_space
@property
def number_agents(self):
return self._n_agents
class ActionFlattener:

5
gym-unity/gym_unity/tests/test_gym.py


ActionType,
DecisionSteps,
TerminalSteps,
BehaviorMapping,
)

setup_mock_unityenvironment(
mock_env, mock_spec, mock_decision_step, mock_terminal_step
)
env = UnityToGymWrapper(mock_env, use_visual=False)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.reset(), np.ndarray)

:Mock mock_decision: A DecisionSteps object that will be returned at each step and reset.
:Mock mock_termination: A TerminationSteps object that will be returned at each step and reset.
"""
mock_env.get_behavior_names.return_value = ["MockBrain"]
mock_env.get_behavior_spec.return_value = mock_spec
mock_env.behavior_specs = BehaviorMapping({"MockBrain": mock_spec})
mock_env.get_steps.return_value = (mock_decision, mock_termination)

50
ml-agents-envs/mlagents_envs/base_env.py


from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import List, NamedTuple, Tuple, Optional, Union, Dict, Iterator, Any
from typing import (
List,
NamedTuple,
Tuple,
Optional,
Union,
Dict,
Iterator,
Any,
Mapping as MappingType,
)
import numpy as np
from enum import Enum

return np.zeros((n_agents, self.action_size), dtype=np.float32)
class BehaviorMapping(Mapping):
def __init__(self, specs: Dict[BehaviorName, BehaviorSpec]):
self._dict = specs
def __len__(self) -> int:
return len(self._dict)
def __getitem__(self, behavior: BehaviorName) -> BehaviorSpec:
return self._dict[behavior]
def __iter__(self) -> Iterator[Any]:
yield from self._dict
class BaseEnv(ABC):
@abstractmethod
def step(self) -> None:

"""
pass
@abstractmethod
def reset(self) -> None:

pass
@abstractmethod
def close(self) -> None:

pass
@property
def get_behavior_names(self) -> List[BehaviorName]:
def behavior_specs(self) -> MappingType[str, BehaviorSpec]:
Returns the list of the behavior names present in the environment.
Returns a Mapping from behavior names to behavior specs.
This list can grow with time as new policies are instantiated.
:return: the list of agent BehaviorName.
Note that new keys can be added to this mapping as new policies are instantiated.
pass
@abstractmethod
def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None:

:param action: A two dimensional np.ndarray corresponding to the action
(either int or float)
"""
pass
@abstractmethod
def set_action_for_agent(

:param action: A one dimensional np.ndarray corresponding to the action
(either int or float)
"""
pass
@abstractmethod
def get_steps(

rewards, agent ids and interrupted flags of the agents that had their
episode terminated last step.
"""
pass
@abstractmethod
def get_behavior_spec(self, behavior_name: BehaviorName) -> BehaviorSpec:
"""
Get the BehaviorSpec corresponding to the behavior name
:param behavior_name: The name of the behavior the agents are part of
:return: A BehaviorSpec corresponding to that behavior
"""
pass

293
ml-agents-envs/mlagents_envs/environment.py


import atexit
from distutils.version import StrictVersion
import glob
import uuid
from typing import Dict, List, Optional, Any, Tuple
from typing import Dict, List, Optional, Tuple, Mapping as MappingType
from mlagents_envs.side_channel.side_channel import SideChannel, IncomingMessage
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.side_channel_manager import SideChannelManager
from mlagents_envs import env_utils
from mlagents_envs.base_env import (
BaseEnv,

BehaviorName,
AgentId,
BehaviorMapping,
)
from mlagents_envs.timers import timed, hierarchical_timer
from mlagents_envs.exception import (

from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
from .rpc_communicator import RpcCommunicator
from sys import platform
import struct
SCALAR_ACTION_TYPES = (int, np.int32, np.int64, float, np.float32, np.float64)
SINGLE_BRAIN_ACTION_TYPES = SCALAR_ACTION_TYPES + (list, np.ndarray)
# Communication protocol version.
# When connecting to C#, this must be compatible with Academy.k_ApiVersion.
# We follow semantic versioning on the communication version, so existing

BASE_ENVIRONMENT_PORT = 5005
# Command line argument used to pass the port to the executable environment.
PORT_COMMAND_LINE_ARG = "--mlagents-port"
_PORT_COMMAND_LINE_ARG = "--mlagents-port"
@staticmethod
def _raise_version_exception(unity_com_ver: str) -> None:

)
@staticmethod
def check_communication_compatibility(
def _check_communication_compatibility(
unity_com_ver: str, python_api_version: str, unity_package_version: str
) -> bool:
unity_communicator_version = StrictVersion(unity_com_ver)

return True
@staticmethod
def get_capabilities_proto() -> UnityRLCapabilitiesProto:
def _get_capabilities_proto() -> UnityRLCapabilitiesProto:
def warn_csharp_base_capabitlities(
def _warn_csharp_base_capabilities(
caps: UnityRLCapabilitiesProto, unity_package_ver: str, python_package_ver: str
) -> None:
if not caps.baseRLCapabilities:

:str log_folder: Optional folder to write the Unity Player log file into. Requires absolute path.
"""
atexit.register(self._close)
self.additional_args = additional_args or []
self.no_graphics = no_graphics
self._additional_args = additional_args or []
self._no_graphics = no_graphics
# If base port is not specified, use BASE_ENVIRONMENT_PORT if we have
# an environment, otherwise DEFAULT_EDITOR_PORT
if base_port is None:

self.port = base_port + worker_id
self._port = base_port + worker_id
self.proc1 = None
self.timeout_wait: int = timeout_wait
self.communicator = self.get_communicator(worker_id, base_port, timeout_wait)
self.worker_id = worker_id
self.side_channels: Dict[uuid.UUID, SideChannel] = {}
if side_channels is not None:
for _sc in side_channels:
if _sc.channel_id in self.side_channels:
raise UnityEnvironmentException(
"There cannot be two side channels with the same channel id {0}.".format(
_sc.channel_id
)
)
self.side_channels[_sc.channel_id] = _sc
self.log_folder = log_folder
self._proc1 = None
self._timeout_wait: int = timeout_wait
self._communicator = self._get_communicator(worker_id, base_port, timeout_wait)
self._worker_id = worker_id
self._side_channel_manager = SideChannelManager(side_channels)
self._log_folder = log_folder
# If the environment name is None, a new environment will not be launched
# and the communicator will directly try to connect to an existing unity environment.

"the worker-id must be 0 in order to connect with the Editor."
)
if file_name is not None:
self.executable_launcher(file_name, no_graphics, additional_args)
try:
self._proc1 = env_utils.launch_executable(
file_name, self._executable_args()
)
except UnityEnvironmentException:
self._close(0)
raise
f"Listening on port {self.port}. "
f"Listening on port {self._port}. "
f"Start training by pressing the Play button in the Unity Editor."
)
self._loaded = True

communication_version=self.API_VERSION,
package_version=mlagents_envs.__version__,
capabilities=UnityEnvironment.get_capabilities_proto(),
capabilities=UnityEnvironment._get_capabilities_proto(),
aca_output = self.send_academy_parameters(rl_init_parameters_in)
aca_output = self._send_academy_parameters(rl_init_parameters_in)
if not UnityEnvironment.check_communication_compatibility(
if not UnityEnvironment._check_communication_compatibility(
aca_params.communication_version,
UnityEnvironment.API_VERSION,
aca_params.package_version,

UnityEnvironment.warn_csharp_base_capabitlities(
UnityEnvironment._warn_csharp_base_capabilities(
aca_params.capabilities,
aca_params.package_version,
UnityEnvironment.API_VERSION,

self._update_behavior_specs(aca_output)
@staticmethod
def get_communicator(worker_id, base_port, timeout_wait):
def _get_communicator(worker_id, base_port, timeout_wait):
@staticmethod
def validate_environment_path(env_path: str) -> Optional[str]:
# Strip out executable extensions if passed
env_path = (
env_path.strip()
.replace(".app", "")
.replace(".exe", "")
.replace(".x86_64", "")
.replace(".x86", "")
)
true_filename = os.path.basename(os.path.normpath(env_path))
logger.debug("The true file name is {}".format(true_filename))
if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
return None
cwd = os.getcwd()
launch_string = None
true_filename = os.path.basename(os.path.normpath(env_path))
if platform == "linux" or platform == "linux2":
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86_64")
if len(candidates) == 0:
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86")
if len(candidates) == 0:
candidates = glob.glob(env_path + ".x86_64")
if len(candidates) == 0:
candidates = glob.glob(env_path + ".x86")
if len(candidates) > 0:
launch_string = candidates[0]
elif platform == "darwin":
candidates = glob.glob(
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename)
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(env_path + ".app", "Contents", "MacOS", true_filename)
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*")
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(env_path + ".app", "Contents", "MacOS", "*")
)
if len(candidates) > 0:
launch_string = candidates[0]
elif platform == "win32":
candidates = glob.glob(os.path.join(cwd, env_path + ".exe"))
if len(candidates) == 0:
candidates = glob.glob(env_path + ".exe")
if len(candidates) > 0:
launch_string = candidates[0]
return launch_string
def executable_args(self) -> List[str]:
def _executable_args(self) -> List[str]:
if self.no_graphics:
if self._no_graphics:
args += [UnityEnvironment.PORT_COMMAND_LINE_ARG, str(self.port)]
if self.log_folder:
args += [UnityEnvironment._PORT_COMMAND_LINE_ARG, str(self._port)]
if self._log_folder:
self.log_folder, f"Player-{self.worker_id}.log"
self._log_folder, f"Player-{self._worker_id}.log"
args += self.additional_args
args += self._additional_args
def executable_launcher(self, file_name, no_graphics, args):
launch_string = self.validate_environment_path(file_name)
if launch_string is None:
self._close(0)
raise UnityEnvironmentException(
f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
)
else:
logger.debug("This is the launch string {}".format(launch_string))
# Launch Unity environment
subprocess_args = [launch_string] + self.executable_args()
try:
self.proc1 = subprocess.Popen(
subprocess_args,
# start_new_session=True means that signals to the parent python process
# (e.g. SIGINT from keyboard interrupt) will not be sent to the new process on POSIX platforms.
# This is generally good since we want the environment to have a chance to shutdown,
# but may be undesirable in come cases; if so, we'll add a command-line toggle.
# Note that on Windows, the CTRL_C signal will still be sent.
start_new_session=True,
)
except PermissionError as perm:
# This is likely due to missing read or execute permissions on file.
raise UnityEnvironmentException(
f"Error when trying to launch environment - make sure "
f"permissions are set correctly. For example "
f'"chmod -R 755 {launch_string}"'
) from perm
def _update_behavior_specs(self, output: UnityOutputProto) -> None:
init_output = output.rl_initialization_output
for brain_param in init_output.brain_parameters:

DecisionSteps.empty(self._env_specs[brain_name]),
TerminalSteps.empty(self._env_specs[brain_name]),
)
self._parse_side_channel_message(self.side_channels, output.side_channel)
self._side_channel_manager.process_side_channel_message(output.side_channel)
outputs = self.communicator.exchange(self._generate_reset_input())
outputs = self._communicator.exchange(self._generate_reset_input())
if outputs is None:
raise UnityCommunicatorStoppedException("Communicator has exited.")
self._update_behavior_specs(outputs)

].create_empty_action(n_agents)
step_input = self._generate_step_input(self._env_actions)
with hierarchical_timer("communicator.exchange"):
outputs = self.communicator.exchange(step_input)
outputs = self._communicator.exchange(step_input)
if outputs is None:
raise UnityCommunicatorStoppedException("Communicator has exited.")
self._update_behavior_specs(outputs)

def get_behavior_names(self):
return list(self._env_specs.keys())
@property
def behavior_specs(self) -> MappingType[str, BehaviorSpec]:
return BehaviorMapping(self._env_specs)
def _assert_behavior_exists(self, behavior_name: str) -> None:
if behavior_name not in self._env_specs:

expected_shape = (len(self._env_state[behavior_name][0]), spec.action_size)
if action.shape != expected_shape:
raise UnityActionException(
"The behavior {0} needs an input of dimension {1} but received input of dimension {2}".format(
behavior_name, expected_shape, action.shape
)
"The behavior {0} needs an input of dimension {1} for "
"(<number of agents>, <action size>) but received input of "
"dimension {2}".format(behavior_name, expected_shape, action.shape)
)
if action.dtype != expected_type:
action = action.astype(expected_type)

self._assert_behavior_exists(behavior_name)
return self._env_state[behavior_name]
def get_behavior_spec(self, behavior_name: BehaviorName) -> BehaviorSpec:
self._assert_behavior_exists(behavior_name)
return self._env_specs[behavior_name]
def close(self):
"""
Sends a shutdown signal to the unity environment, and closes the socket connection.

force-killing it. Defaults to `self.timeout_wait`.
"""
if timeout is None:
timeout = self.timeout_wait
timeout = self._timeout_wait
self.communicator.close()
if self.proc1 is not None:
self._communicator.close()
if self._proc1 is not None:
self.proc1.wait(timeout=timeout)
signal_name = self.returncode_to_signal_name(self.proc1.returncode)
self._proc1.wait(timeout=timeout)
signal_name = self._returncode_to_signal_name(self._proc1.returncode)
return_info = f"Environment shut down with return code {self.proc1.returncode}{signal_name}."
return_info = f"Environment shut down with return code {self._proc1.returncode}{signal_name}."
self.proc1.kill()
self._proc1.kill()
self.proc1 = None
@classmethod
def _flatten(cls, arr: Any) -> List[float]:
"""
Converts arrays to list.
:param arr: numpy vector.
:return: flattened list.
"""
if isinstance(arr, cls.SCALAR_ACTION_TYPES):
arr = [float(arr)]
if isinstance(arr, np.ndarray):
arr = arr.tolist()
if len(arr) == 0:
return arr
if isinstance(arr[0], np.ndarray):
# pylint: disable=no-member
arr = [item for sublist in arr for item in sublist.tolist()]
if isinstance(arr[0], list):
# pylint: disable=not-an-iterable
arr = [item for sublist in arr for item in sublist]
arr = [float(x) for x in arr]
return arr
@staticmethod
def _parse_side_channel_message(
side_channels: Dict[uuid.UUID, SideChannel], data: bytes
) -> None:
offset = 0
while offset < len(data):
try:
channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16]))
offset += 16
message_len, = struct.unpack_from("<i", data, offset)
offset = offset + 4
message_data = data[offset : offset + message_len]
offset = offset + message_len
except Exception:
raise UnityEnvironmentException(
"There was a problem reading a message in a SideChannel. "
"Please make sure the version of MLAgents in Unity is "
"compatible with the Python version."
)
if len(message_data) != message_len:
raise UnityEnvironmentException(
"The message received by the side channel {0} was "
"unexpectedly short. Make sure your Unity Environment "
"sending side channel data properly.".format(channel_id)
)
if channel_id in side_channels:
incoming_message = IncomingMessage(message_data)
side_channels[channel_id].on_message_received(incoming_message)
else:
logger.warning(
"Unknown side channel data received. Channel type "
": {0}.".format(channel_id)
)
@staticmethod
def _generate_side_channel_data(
side_channels: Dict[uuid.UUID, SideChannel]
) -> bytearray:
result = bytearray()
for channel_id, channel in side_channels.items():
for message in channel.message_queue:
result += channel_id.bytes_le
result += struct.pack("<i", len(message))
result += message
channel.message_queue = []
return result
self._proc1 = None
@timed
def _generate_step_input(

action = AgentActionProto(vector_actions=vector_action[b][i])
rl_in.agent_actions[b].value.extend([action])
rl_in.command = STEP
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
return self.wrap_unity_input(rl_in)
rl_in.side_channel = bytes(
self._side_channel_manager.generate_side_channel_messages()
)
return self._wrap_unity_input(rl_in)
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
return self.wrap_unity_input(rl_in)
rl_in.side_channel = bytes(
self._side_channel_manager.generate_side_channel_messages()
)
return self._wrap_unity_input(rl_in)
def send_academy_parameters(
def _send_academy_parameters(
return self.communicator.initialize(inputs)
return self._communicator.initialize(inputs)
def wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto:
def _wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto:
def returncode_to_signal_name(returncode: int) -> Optional[str]:
def _returncode_to_signal_name(returncode: int) -> Optional[str]:
"""
Try to convert return codes into their corresponding signal name.
E.g. returncode_to_signal_name(-2) -> "SIGINT"

54
ml-agents-envs/mlagents_envs/tests/test_envs.py


from mlagents_envs.mock_communicator import MockCommunicator
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
assert env.get_behavior_names() == ["RealFakeBrain"]
assert list(env.behavior_specs.keys()) == ["RealFakeBrain"]
env.close()

(None, None, UnityEnvironment.DEFAULT_EDITOR_PORT),
],
)
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_port_defaults(
mock_communicator, mock_launcher, base_port, file_name, expected
):

env = UnityEnvironment(file_name=file_name, worker_id=0, base_port=base_port)
assert expected == env.port
assert expected == env._port
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
args = env.executable_args()
args = env._executable_args()
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
spec = env.get_behavior_spec("RealFakeBrain")
spec = env.behavior_specs["RealFakeBrain"]
env.reset()
decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
env.close()

assert (n_agents,) + shape == obs.shape
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
spec = env.get_behavior_spec("RealFakeBrain")
spec = env.behavior_specs["RealFakeBrain"]
env.step()
decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
n_agents = len(decision_steps)

assert 2 in terminal_steps
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_close(mock_communicator, mock_launcher):
comm = MockCommunicator(discrete_action=False, visual_inputs=0)
mock_communicator.return_value = comm

unity_ver = "1.0.0"
python_ver = "1.0.0"
unity_package_version = "0.15.0"
assert UnityEnvironment.check_communication_compatibility(
assert UnityEnvironment._check_communication_compatibility(
assert UnityEnvironment.check_communication_compatibility(
assert UnityEnvironment._check_communication_compatibility(
assert not UnityEnvironment.check_communication_compatibility(
assert not UnityEnvironment._check_communication_compatibility(
assert UnityEnvironment.check_communication_compatibility(
assert UnityEnvironment._check_communication_compatibility(
assert not UnityEnvironment.check_communication_compatibility(
assert not UnityEnvironment._check_communication_compatibility(
assert not UnityEnvironment.check_communication_compatibility(
assert not UnityEnvironment._check_communication_compatibility(
assert UnityEnvironment.returncode_to_signal_name(-2) == "SIGINT"
assert UnityEnvironment.returncode_to_signal_name(42) is None
assert UnityEnvironment.returncode_to_signal_name("SIGINT") is None
assert UnityEnvironment._returncode_to_signal_name(-2) == "SIGINT"
assert UnityEnvironment._returncode_to_signal_name(42) is None
assert UnityEnvironment._returncode_to_signal_name("SIGINT") is None
if __name__ == "__main__":

47
ml-agents-envs/mlagents_envs/tests/test_side_channel.py


import uuid
import pytest
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from mlagents_envs.side_channel.side_channel_manager import SideChannelManager
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel
from mlagents_envs.side_channel.raw_bytes_channel import RawBytesChannel
from mlagents_envs.side_channel.engine_configuration_channel import (

StatsSideChannel,
StatsAggregationMethod,
)
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.exception import (
UnitySideChannelException,
UnityCommunicationException,

receiver = IntChannel()
sender.send_int(5)
sender.send_int(6)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
assert receiver.list_int[0] == 5
assert receiver.list_int[1] == 6

sender.set_property("prop1", 1.0)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
val = receiver.get_property("prop1")
assert val == 1.0

data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
val = receiver.get_property("prop1")
assert val == 1.0

sender.send_raw_data("foo".encode("ascii"))
sender.send_raw_data("bar".encode("ascii"))
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
messages = receiver.get_and_clear_received_messages()
assert len(messages) == 2

config = EngineConfig.default_config()
sender.set_configuration(config)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
received_data = receiver.get_and_clear_received_messages()
assert len(received_data) == 5 # 5 different messages one for each setting

data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
message.read_int32()

with pytest.raises(UnityCommunicationException):
# try to send data to the EngineConfigurationChannel
sender.set_configuration_parameters(time_scale=sent_time_scale)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_id: sender}, data
)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([sender]).process_side_channel_message(data)
def test_environment_parameters():

sender.set_float_parameter("param-1", 0.1)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
key = message.read_string()

sender.set_float_parameter("param-2", 0.1)
sender.set_float_parameter("param-3", 0.1)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
assert len(receiver.get_and_clear_received_messages()) == 3

data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_id: sender}, data
)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([sender]).process_side_channel_message(data)
def test_stats_channel():

14
ml-agents/mlagents/trainers/learn.py


from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.timers import (
hierarchical_timer,
get_timer_tree,

os.path.join(base_path, options.run_id) if options.initialize_from else None
)
run_logs_dir = os.path.join(write_path, "run_logs")
port = options.base_port
port: Optional[int] = options.base_port
# Check if directory exists
handle_existing_directories(
write_path, options.resume, options.force, maybe_init_path

StatsReporter.add_writer(console_writer)
if options.env_path is None:
port = UnityEnvironment.DEFAULT_EDITOR_PORT
port = None
env_factory = create_environment_factory(
options.env_path,
options.no_graphics,

env_path: Optional[str],
no_graphics: bool,
seed: int,
start_port: int,
start_port: Optional[int],
if env_path is not None:
launch_string = UnityEnvironment.validate_environment_path(env_path)
if launch_string is None:
raise UnityEnvironmentException(
f"Couldn't launch the {env_path} environment. Provided filename does not match any environments."
)
def create_unity_environment(
worker_id: int, side_channels: List[SideChannel]
) -> UnityEnvironment:

1
ml-agents/mlagents/trainers/policy/tf_policy.py


)
)
if reset_global_steps:
self._set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path

4
ml-agents/mlagents/trainers/ppo/trainer.py


self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
buffer = self.update_buffer
max_num_batch = buffer_length // batch_size
for l in range(0, max_num_batch * batch_size, batch_size):
for i in range(0, max_num_batch * batch_size, batch_size):
buffer.make_mini_batch(l, l + batch_size), n_sequences
buffer.make_mini_batch(i, i + batch_size), n_sequences
)
for stat_name, value in update_stats.items():
batch_update_stats[stat_name].append(value)

8
ml-agents/mlagents/trainers/simple_env_manager.py


@property
def external_brains(self) -> Dict[BehaviorName, BrainParameters]:
result = {}
for brain_name in self.env.get_behavior_names():
result[brain_name] = behavior_spec_to_brain_parameters(
brain_name, self.env.get_behavior_spec(brain_name)
for behavior_name, behavior_spec in self.env.behavior_specs.items():
result[behavior_name] = behavior_spec_to_brain_parameters(
behavior_name, behavior_spec
)
return result

def _generate_all_results(self) -> AllStepResult:
all_step_result: AllStepResult = {}
for brain_name in self.env.get_behavior_names():
for brain_name in self.env.behavior_specs:
all_step_result[brain_name] = self.env.get_steps(brain_name)
return all_step_result

10
ml-agents/mlagents/trainers/subprocess_env_manager.py


def _generate_all_results() -> AllStepResult:
all_step_result: AllStepResult = {}
for brain_name in env.get_behavior_names():
for brain_name in env.behavior_specs:
for brain_name in env.get_behavior_names():
result[brain_name] = behavior_spec_to_brain_parameters(
brain_name, env.get_behavior_spec(brain_name)
for behavior_name, behavior_specs in env.behavior_specs.items():
result[behavior_name] = behavior_spec_to_brain_parameters(
behavior_name, behavior_specs
)
return result

return self.env_workers[0].recv().payload
def close(self) -> None:
logger.debug(f"SubprocessEnvManager closing.")
logger.debug("SubprocessEnvManager closing.")
self.step_queue.close()
self.step_queue.join_thread()
for env_worker in self.env_workers:

12
ml-agents/mlagents/trainers/tests/simple_test_envs.py


DecisionSteps,
TerminalSteps,
ActionType,
BehaviorMapping,
)
from mlagents_envs.tests.test_rpc_utils import proto_from_steps_and_action
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (

obs.append(np.ones((1,) + self.vis_obs_size, dtype=np.float32) * value)
return obs
def get_behavior_names(self):
return self.names
def get_behavior_spec(self, behavior_name):
return self.behavior_spec
@property
def behavior_specs(self):
behavior_dict = {}
for n in self.names:
behavior_dict[n] = self.behavior_spec
return BehaviorMapping(behavior_dict)
def set_action_for_agent(self, behavior_name, agent_id, action):
pass

5
ml-agents/mlagents/trainers/tests/test_learn.py


def test_bad_env_path():
with pytest.raises(UnityEnvironmentException):
learn.create_environment_factory(
factory = learn.create_environment_factory(
seed=None,
seed=-1,
factory(worker_id=-1, side_channels=[])
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_YAML)

4
ml-agents/mlagents/trainers/tests/test_nn_policy.py


trainer_params["output_path"] = path1
policy = create_policy_mock(trainer_params)
policy.initialize_or_load()
policy._set_step(2000)
policy.save_model(2000)
assert len(os.listdir(tmp_path)) > 0

policy2.initialize_or_load()
_compare_two_policies(policy, policy2)
assert policy2.get_current_step() == 2000
# Try initialize from path 1
trainer_params["model_path"] = path2

_compare_two_policies(policy2, policy3)
# Assert that the steps are 0.
assert policy3.get_current_step() == 0
def _compare_two_policies(policy1: NNPolicy, policy2: NNPolicy) -> None:

6
ml-agents/tests/yamato/check_coverage_percent.py


# Rather than try to parse the XML, just look for a line of the form
# <Linecoverage>73.9</Linecoverage>
lines = f.readlines()
for l in lines:
if "Linecoverage" in l:
pct = l.replace("<Linecoverage>", "").replace("</Linecoverage>", "")
for line in lines:
if "Linecoverage" in line:
pct = line.replace("<Linecoverage>", "").replace("</Linecoverage>", "")
pct = float(pct)
if pct < min_percentage:
print(

4
ml-agents/tests/yamato/scripts/run_llapi.py


env.reset()
# Set the default brain to work with
group_name = env.get_behavior_names()[0]
group_spec = env.get_behavior_spec(group_name)
group_name = list(env.behavior_specs.keys())[0]
group_spec = env.behavior_specs[group_name]
# Set the time scale of the engine
engine_configuration_channel.set_configuration_parameters(time_scale=3.0)

2
ml-agents/tests/yamato/yamato_utils.py


subprocess.check_call("git reset HEAD .", shell=True)
subprocess.check_call("git checkout -- .", shell=True)
# Ensure the cache isn't polluted with old compiled assemblies.
subprocess.check_call(f"rm -rf Project/Library", shell=True)
subprocess.check_call("rm -rf Project/Library", shell=True)
def override_config_file(src_path, dest_path, **kwargs):

6
utils/validate_versions.py


def extract_version_string(filename):
with open(filename) as f:
for l in f.readlines():
if l.startswith(VERSION_LINE_START):
return l.replace(VERSION_LINE_START, "").strip()
for line in f.readlines():
if line.startswith(VERSION_LINE_START):
return line.replace(VERSION_LINE_START, "").strip()
return None

8
com.unity.ml-agents/Runtime/Sensors/Reflection.meta


fileFormatVersion: 2
guid: 08ece3d7e9bb94089a9d59c6f269ab0a
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta


fileFormatVersion: 2
guid: e5e4df2934c014aa3b835b9eb9ad20b3
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

292
com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs


using System;
using System.Collections.Generic;
using NUnit.Framework;
using UnityEngine;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class ObservableAttributeTests
{
class TestClass
{
// Non-observables
int m_NonObservableInt;
float m_NonObservableFloat;
//
// Int
//
[Observable]
public int m_IntMember;
int m_IntProperty;
[Observable]
public int IntProperty
{
get => m_IntProperty;
set => m_IntProperty = value;
}
//
// Float
//
[Observable("floatMember")]
public float m_FloatMember;
float m_FloatProperty;
[Observable("floatProperty")]
public float FloatProperty
{
get => m_FloatProperty;
set => m_FloatProperty = value;
}
//
// Bool
//
[Observable("boolMember")]
public bool m_BoolMember;
bool m_BoolProperty;
[Observable("boolProperty")]
public bool BoolProperty
{
get => m_BoolProperty;
set => m_BoolProperty = value;
}
//
// Vector2
//
[Observable("vector2Member")]
public Vector2 m_Vector2Member;
Vector2 m_Vector2Property;
[Observable("vector2Property")]
public Vector2 Vector2Property
{
get => m_Vector2Property;
set => m_Vector2Property = value;
}
//
// Vector3
//
[Observable("vector3Member")]
public Vector3 m_Vector3Member;
Vector3 m_Vector3Property;
[Observable("vector3Property")]
public Vector3 Vector3Property
{
get => m_Vector3Property;
set => m_Vector3Property = value;
}
//
// Vector4
//
[Observable("vector4Member")]
public Vector4 m_Vector4Member;
Vector4 m_Vector4Property;
[Observable("vector4Property")]
public Vector4 Vector4Property
{
get => m_Vector4Property;
set => m_Vector4Property = value;
}
//
// Quaternion
//
[Observable("quaternionMember")]
public Quaternion m_QuaternionMember;
Quaternion m_QuaternionProperty;
[Observable("quaternionProperty")]
public Quaternion QuaternionProperty
{
get => m_QuaternionProperty;
set => m_QuaternionProperty = value;
}
}
[Test]
public void TestGetObservableSensors()
{
var testClass = new TestClass();
testClass.m_IntMember = 1;
testClass.IntProperty = 2;
testClass.m_FloatMember = 1.1f;
testClass.FloatProperty = 1.2f;
testClass.m_BoolMember = true;
testClass.BoolProperty = true;
testClass.m_Vector2Member = new Vector2(2.0f, 2.1f);
testClass.Vector2Property = new Vector2(2.2f, 2.3f);
testClass.m_Vector3Member = new Vector3(3.0f, 3.1f, 3.2f);
testClass.Vector3Property = new Vector3(3.3f, 3.4f, 3.5f);
testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f);
testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f);
testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f);
testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f);
testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f);
testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f);
var sensors = ObservableAttribute.CreateObservableSensors(testClass, false);
var sensorsByName = new Dictionary<string, ISensor>();
foreach (var sensor in sensors)
{
sensorsByName[sensor.GetName()] = sensor;
}
SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] { 1.0f });
SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] { 2.0f });
SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] { 1.1f });
SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] { 1.2f });
SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] { 1.0f });
SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] { 1.0f });
SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] { 2.0f, 2.1f });
SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] { 2.2f, 2.3f });
SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] { 3.0f, 3.1f, 3.2f });
SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] { 3.3f, 3.4f, 3.5f });
SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] { 4.0f, 4.1f, 4.2f, 4.3f });
SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] { 4.4f, 4.5f, 4.5f, 4.7f });
SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f });
SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f });
}
[Test]
public void TestGetTotalObservationSize()
{
var testClass = new TestClass();
var errors = new List<string>();
var expectedObsSize = 2 * (1 + 1 + 1 + 2 + 3 + 4 + 4);
Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, false, errors));
Assert.AreEqual(0, errors.Count);
}
class BadClass
{
[Observable]
double m_Double;
[Observable]
double DoubleProperty
{
get => m_Double;
set => m_Double = value;
}
float m_WriteOnlyProperty;
[Observable]
// No get property, so we shouldn't be able to make a sensor out of this.
public float WriteOnlyProperty
{
set => m_WriteOnlyProperty = value;
}
}
[Test]
public void TestInvalidObservables()
{
var bad = new BadClass();
bad.WriteOnlyProperty = 1.0f;
var errors = new List<string>();
Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors));
Assert.AreEqual(3, errors.Count);
// Should be able to safely generate sensors (and get nothing back)
var sensors = ObservableAttribute.CreateObservableSensors(bad, false);
Assert.AreEqual(0, sensors.Count);
}
class StackingClass
{
[Observable(numStackedObservations: 2)]
public float FloatVal;
}
[Test]
public void TestObservableAttributeStacking()
{
var c = new StackingClass();
c.FloatVal = 1.0f;
var sensors = ObservableAttribute.CreateObservableSensors(c, false);
var sensor = sensors[0];
Assert.AreEqual(typeof(StackingSensor), sensor.GetType());
SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f });
sensor.Update();
c.FloatVal = 3.0f;
SensorTestHelper.CompareObservation(sensor, new[] { 1.0f, 3.0f });
var errors = new List<string>();
Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, false, errors));
Assert.AreEqual(0, errors.Count);
}
class BaseClass
{
[Observable("base")]
public float m_BaseField;
[Observable("private")]
float m_PrivateField;
}
class DerivedClass : BaseClass
{
[Observable("derived")]
float m_DerivedField;
}
[Test]
public void TestObservableAttributeExcludeInherited()
{
var d = new DerivedClass();
d.m_BaseField = 1.0f;
// excludeInherited=false will get fields in the derived class, plus public and protected inherited fields
var sensorAll = ObservableAttribute.CreateObservableSensors(d, false);
Assert.AreEqual(2, sensorAll.Count);
// Note - actual order doesn't matter here, we can change this to use a HashSet if neeed.
Assert.AreEqual("derived", sensorAll[0].GetName());
Assert.AreEqual("base", sensorAll[1].GetName());
// excludeInherited=true will only get fields in the derived class
var sensorsDerivedOnly = ObservableAttribute.CreateObservableSensors(d, true);
Assert.AreEqual(1, sensorsDerivedOnly.Count);
Assert.AreEqual("derived", sensorsDerivedOnly[0].GetName());
var b = new BaseClass();
var baseSensors = ObservableAttribute.CreateObservableSensors(b, false);
Assert.AreEqual(2, baseSensors.Count);
}
}
}

11
com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta


fileFormatVersion: 2
guid: 33d7912e6b3504412bd261b40e46df32
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

108
ml-agents-envs/mlagents_envs/env_utils.py


import glob
import os
import subprocess
from sys import platform
from typing import Optional, List
from mlagents_envs.logging_util import get_logger
from mlagents_envs.exception import UnityEnvironmentException
def get_platform():
"""
returns the platform of the operating system : linux, darwin or win32
"""
return platform
def validate_environment_path(env_path: str) -> Optional[str]:
"""
Strip out executable extensions of the env_path
:param env_path: The path to the executable
"""
env_path = (
env_path.strip()
.replace(".app", "")
.replace(".exe", "")
.replace(".x86_64", "")
.replace(".x86", "")
)
true_filename = os.path.basename(os.path.normpath(env_path))
get_logger(__name__).debug("The true file name is {}".format(true_filename))
if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
return None
cwd = os.getcwd()
launch_string = None
true_filename = os.path.basename(os.path.normpath(env_path))
if get_platform() == "linux" or get_platform() == "linux2":
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86_64")
if len(candidates) == 0:
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86")
if len(candidates) == 0:
candidates = glob.glob(env_path + ".x86_64")
if len(candidates) == 0:
candidates = glob.glob(env_path + ".x86")
if len(candidates) > 0:
launch_string = candidates[0]
elif get_platform() == "darwin":
candidates = glob.glob(
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename)
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(env_path + ".app", "Contents", "MacOS", true_filename)
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*")
)
if len(candidates) == 0:
candidates = glob.glob(
os.path.join(env_path + ".app", "Contents", "MacOS", "*")
)
if len(candidates) > 0:
launch_string = candidates[0]
elif get_platform() == "win32":
candidates = glob.glob(os.path.join(cwd, env_path + ".exe"))
if len(candidates) == 0:
candidates = glob.glob(env_path + ".exe")
if len(candidates) > 0:
launch_string = candidates[0]
return launch_string
def launch_executable(file_name: str, args: List[str]) -> subprocess.Popen:
"""
Launches a Unity executable and returns the process handle for it.
:param file_name: the name of the executable
:param args: List of string that will be passed as command line arguments
when launching the executable.
"""
launch_string = validate_environment_path(file_name)
if launch_string is None:
raise UnityEnvironmentException(
f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
)
else:
get_logger(__name__).debug("This is the launch string {}".format(launch_string))
# Launch Unity environment
subprocess_args = [launch_string] + args
try:
return subprocess.Popen(
subprocess_args,
# start_new_session=True means that signals to the parent python process
# (e.g. SIGINT from keyboard interrupt) will not be sent to the new process on POSIX platforms.
# This is generally good since we want the environment to have a chance to shutdown,
# but may be undesirable in come cases; if so, we'll add a command-line toggle.
# Note that on Windows, the CTRL_C signal will still be sent.
start_new_session=True,
)
except PermissionError as perm:
# This is likely due to missing read or execute permissions on file.
raise UnityEnvironmentException(
f"Error when trying to launch environment - make sure "
f"permissions are set correctly. For example "
f'"chmod -R 755 {launch_string}"'
) from perm

81
ml-agents-envs/mlagents_envs/side_channel/side_channel_manager.py


import uuid
import struct
from typing import Dict, Optional, List
from mlagents_envs.side_channel import SideChannel, IncomingMessage
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.logging_util import get_logger
class SideChannelManager:
def __init__(self, side_channels=Optional[List[SideChannel]]):
self._side_channels_dict = self._get_side_channels_dict(side_channels)
def process_side_channel_message(self, data: bytes) -> None:
"""
Separates the data received from Python into individual messages for each
registered side channel and calls on_message_received on them.
:param data: The packed message sent by Unity
"""
offset = 0
while offset < len(data):
try:
channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16]))
offset += 16
message_len, = struct.unpack_from("<i", data, offset)
offset = offset + 4
message_data = data[offset : offset + message_len]
offset = offset + message_len
except (struct.error, ValueError, IndexError):
raise UnityEnvironmentException(
"There was a problem reading a message in a SideChannel. "
"Please make sure the version of MLAgents in Unity is "
"compatible with the Python version."
)
if len(message_data) != message_len:
raise UnityEnvironmentException(
"The message received by the side channel {0} was "
"unexpectedly short. Make sure your Unity Environment "
"sending side channel data properly.".format(channel_id)
)
if channel_id in self._side_channels_dict:
incoming_message = IncomingMessage(message_data)
self._side_channels_dict[channel_id].on_message_received(
incoming_message
)
else:
get_logger(__name__).warning(
f"Unknown side channel data received. Channel type: {channel_id}."
)
def generate_side_channel_messages(self) -> bytearray:
"""
Gathers the messages that the registered side channels will send to Unity
and combines them into a single message ready to be sent.
"""
result = bytearray()
for channel_id, channel in self._side_channels_dict.items():
for message in channel.message_queue:
result += channel_id.bytes_le
result += struct.pack("<i", len(message))
result += message
channel.message_queue = []
return result
@staticmethod
def _get_side_channels_dict(
side_channels: Optional[List[SideChannel]]
) -> Dict[uuid.UUID, SideChannel]:
"""
Converts a list of side channels into a dictionary of channel_id to SideChannel
:param side_channels: The list of side channels.
"""
side_channels_dict: Dict[uuid.UUID, SideChannel] = {}
if side_channels is not None:
for _sc in side_channels:
if _sc.channel_id in side_channels_dict:
raise UnityEnvironmentException(
f"There cannot be two side channels with "
f"the same channel id {_sc.channel_id}."
)
side_channels_dict[_sc.channel_id] = _sc
return side_channels_dict

64
ml-agents-envs/mlagents_envs/tests/test_env_utils.py


from unittest import mock
import pytest
from mlagents_envs.env_utils import validate_environment_path, launch_executable
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.logging_util import (
set_log_level,
get_logger,
INFO,
ERROR,
FATAL,
CRITICAL,
DEBUG,
)
def mock_glob_method(path):
"""
Given a path input, returns a list of candidates
"""
if ".x86" in path:
return ["linux"]
if ".app" in path:
return ["darwin"]
if ".exe" in path:
return ["win32"]
if "*" in path:
return "Any"
return []
@mock.patch("sys.platform")
@mock.patch("glob.glob")
def test_validate_path_empty(glob_mock, platform_mock):
glob_mock.return_value = None
path = validate_environment_path(" ")
assert path is None
@mock.patch("mlagents_envs.env_utils.get_platform")
@mock.patch("glob.glob")
def test_validate_path(glob_mock, platform_mock):
glob_mock.side_effect = mock_glob_method
for platform in ["linux", "darwin", "win32"]:
platform_mock.return_value = platform
path = validate_environment_path(" ")
assert path == platform
@mock.patch("glob.glob")
@mock.patch("subprocess.Popen")
def test_launch_executable(mock_popen, glob_mock):
with pytest.raises(UnityEnvironmentException):
launch_executable(" ", [])
glob_mock.return_value = ["FakeLaunchPath"]
launch_executable(" ", [])
mock_popen.side_effect = PermissionError("Fake permission error")
with pytest.raises(UnityEnvironmentException):
launch_executable(" ", [])
def test_set_logging_level():
for level in [INFO, ERROR, FATAL, CRITICAL, DEBUG]:
set_log_level(level)
assert get_logger("test").level == level

102
ml-agents-envs/mlagents_envs/tests/test_steps.py


import pytest
import numpy as np
from mlagents_envs.base_env import (
DecisionSteps,
TerminalSteps,
ActionType,
BehaviorSpec,
)
def test_decision_steps():
ds = DecisionSteps(
obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)],
reward=np.array(range(3), dtype=np.float32),
agent_id=np.array(range(10, 13), dtype=np.int32),
action_mask=[np.zeros((3, 4), dtype=np.bool)],
)
assert ds.agent_id_to_index[10] == 0
assert ds.agent_id_to_index[11] == 1
assert ds.agent_id_to_index[12] == 2
with pytest.raises(KeyError):
assert ds.agent_id_to_index[-1] == -1
mask_agent = ds[10].action_mask
assert isinstance(mask_agent, list)
assert len(mask_agent) == 1
assert np.array_equal(mask_agent[0], np.zeros((4), dtype=np.bool))
for agent_id in ds:
assert ds.agent_id_to_index[agent_id] in range(3)
def test_empty_decision_steps():
specs = BehaviorSpec(
observation_shapes=[(3, 2), (5,)],
action_type=ActionType.CONTINUOUS,
action_shape=3,
)
ds = DecisionSteps.empty(specs)
assert len(ds.obs) == 2
assert ds.obs[0].shape == (0, 3, 2)
assert ds.obs[1].shape == (0, 5)
def test_terminal_steps():
ts = TerminalSteps(
obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)],
reward=np.array(range(3), dtype=np.float32),
agent_id=np.array(range(10, 13), dtype=np.int32),
interrupted=np.array([1, 0, 1], dtype=np.bool),
)
assert ts.agent_id_to_index[10] == 0
assert ts.agent_id_to_index[11] == 1
assert ts.agent_id_to_index[12] == 2
assert ts[10].interrupted
assert not ts[11].interrupted
assert ts[12].interrupted
with pytest.raises(KeyError):
assert ts.agent_id_to_index[-1] == -1
for agent_id in ts:
assert ts.agent_id_to_index[agent_id] in range(3)
def test_empty_terminal_steps():
specs = BehaviorSpec(
observation_shapes=[(3, 2), (5,)],
action_type=ActionType.CONTINUOUS,
action_shape=3,
)
ts = TerminalSteps.empty(specs)
assert len(ts.obs) == 2
assert ts.obs[0].shape == (0, 3, 2)
assert ts.obs[1].shape == (0, 5)
def test_specs():
specs = BehaviorSpec(
observation_shapes=[(3, 2), (5,)],
action_type=ActionType.CONTINUOUS,
action_shape=3,
)
assert specs.discrete_action_branches is None
assert specs.action_size == 3
assert specs.create_empty_action(5).shape == (5, 3)
assert specs.create_empty_action(5).dtype == np.float32
specs = BehaviorSpec(
observation_shapes=[(3, 2), (5,)],
action_type=ActionType.DISCRETE,
action_shape=(3,),
)
assert specs.discrete_action_branches == (3,)
assert specs.action_size == 1
assert specs.create_empty_action(5).shape == (5, 1)
assert specs.create_empty_action(5).dtype == np.int32

11
com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta


fileFormatVersion: 2
guid: be795c90750a6420d93f569b69ddc1ba
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta


fileFormatVersion: 2
guid: 51ed837d5b7cd44349287ac8066120fc
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta


fileFormatVersion: 2
guid: 5cae4c843cc074d11a549aaa3904c898
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta


fileFormatVersion: 2
guid: a75086dc66a594baea6b8b2935f5dacf
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta


fileFormatVersion: 2
guid: d38241d74074d459bb4590f7f5d16c80
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

97
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


using System.Reflection;
namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Construction info for a ReflectionSensorBase.
/// </summary>
internal struct ReflectionSensorInfo
{
public object Object;
public FieldInfo FieldInfo;
public PropertyInfo PropertyInfo;
public ObservableAttribute ObservableAttribute;
public string SensorName;
}
/// <summary>
/// Abstract base class for reflection-based sensors.
/// </summary>
internal abstract class ReflectionSensorBase : ISensor
{
protected object m_Object;
// Exactly one of m_FieldInfo and m_PropertyInfo should be non-null.
protected FieldInfo m_FieldInfo;
protected PropertyInfo m_PropertyInfo;
// Not currently used, but might want later.
protected ObservableAttribute m_ObservableAttribute;
// Cached sensor names and shapes.
string m_SensorName;
int[] m_Shape;
public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size)
{
m_Object = reflectionSensorInfo.Object;
m_FieldInfo = reflectionSensorInfo.FieldInfo;
m_PropertyInfo = reflectionSensorInfo.PropertyInfo;
m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute;
m_SensorName = reflectionSensorInfo.SensorName;
m_Shape = new[] {size};
}
/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
WriteReflectedField(writer);
return m_Shape[0];
}
internal abstract void WriteReflectedField(ObservationWriter writer);
/// <summary>
/// Get either the reflected field, or return the reflected property.
/// This should be used by implementations in their WriteReflectedField() method.
/// </summary>
/// <returns></returns>
protected object GetReflectedValue()
{
return m_FieldInfo != null ?
m_FieldInfo.GetValue(m_Object) :
m_PropertyInfo.GetMethod.Invoke(m_Object, null);
}
/// <inheritdoc/>
public byte[] GetCompressedObservation()
{
return null;
}
/// <inheritdoc/>
public void Update() {}
/// <inheritdoc/>
public void Reset() {}
/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}
/// <inheritdoc/>
public string GetName()
{
return m_SensorName;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta


fileFormatVersion: 2
guid: 6b68d855fb94a45fbbeb0dbe968a35f8
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta


fileFormatVersion: 2
guid: da06ff33f6f2d409cbf240cffa2ba0be
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta


fileFormatVersion: 2
guid: e756976ec2a0943cfbc0f97a6550a85b
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta


fileFormatVersion: 2
guid: 01d93aaa1b42b47b8960d303d7c498d3
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

19
com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a boolean field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class BoolReflectionSensor : ReflectionSensorBase
{
public BoolReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 1)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var boolVal = (System.Boolean)GetReflectedValue();
writer[0] = boolVal ? 1.0f : 0.0f;
}
}
}

19
com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a float field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class FloatReflectionSensor : ReflectionSensorBase
{
public FloatReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 1)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var floatVal = (System.Single)GetReflectedValue();
writer[0] = floatVal;
}
}
}

19
com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps an integer field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class IntReflectionSensor : ReflectionSensorBase
{
public IntReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 1)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var intVal = (System.Int32)GetReflectedValue();
writer[0] = (float)intVal;
}
}
}

272
com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs


using System;
using System.Collections.Generic;
using System.Reflection;
using UnityEngine;
namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Specify that a field or property should be used to generate observations for an Agent.
/// For each field or property that uses ObservableAttribute, a corresponding
/// <see cref="ISensor"/> will be created during Agent initialization, and this
/// sensor will read the values during training and inference.
/// </summary>
/// <remarks>
/// ObservableAttribute is intended to make initial setup of an Agent easier. Because it
/// uses reflection to read the values of fields and properties at runtime, this may
/// be much slower than reading the values directly. If the performance of
/// ObservableAttribute is an issue, you can get the same functionality by overriding
/// <see cref="Agent.CollectObservations(VectorSensor)"/> or creating a custom
/// <see cref="ISensor"/> implementation to read the values without reflection.
///
/// Note that you do not need to adjust the VectorObservationSize in
/// <see cref="Unity.MLAgents.Policies.BrainParameters"/> when adding ObservableAttribute
/// to fields or properties.
/// </remarks>
/// <example>
/// This sample class will produce two observations, one for the m_Health field, and one
/// for the HealthPercent property.
/// <code>
/// using Unity.MLAgents;
/// using Unity.MLAgents.Sensors.Reflection;
///
/// public class MyAgent : Agent
/// {
/// [Observable]
/// int m_Health;
///
/// [Observable]
/// float HealthPercent
/// {
/// get => return 100.0f * m_Health / float(m_MaxHealth);
/// }
/// }
/// </code>
/// </example>
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
public class ObservableAttribute : Attribute
{
string m_Name;
int m_NumStackedObservations;
/// <summary>
/// Default binding flags used for reflection of members and properties.
/// </summary>
const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
/// <summary>
/// Supported types and their observation sizes and corresponding sensor type.
/// </summary>
static Dictionary<Type, (int, Type)> s_TypeToSensorInfo = new Dictionary<Type, (int, Type)>()
{
{typeof(int), (1, typeof(IntReflectionSensor))},
{typeof(bool), (1, typeof(BoolReflectionSensor))},
{typeof(float), (1, typeof(FloatReflectionSensor))},
{typeof(Vector2), (2, typeof(Vector2ReflectionSensor))},
{typeof(Vector3), (3, typeof(Vector3ReflectionSensor))},
{typeof(Vector4), (4, typeof(Vector4ReflectionSensor))},
{typeof(Quaternion), (4, typeof(QuaternionReflectionSensor))},
};
/// <summary>
/// ObservableAttribute constructor.
/// </summary>
/// <param name="name">Optional override for the sensor name. Note that all sensors for an Agent
/// must have a unique name.</param>
/// <param name="numStackedObservations">Number of frames to concatenate observations from.</param>
public ObservableAttribute(string name = null, int numStackedObservations = 1)
{
m_Name = name;
m_NumStackedObservations = numStackedObservations;
}
/// <summary>
/// Returns a FieldInfo for all fields that have an ObservableAttribute
/// </summary>
/// <param name="o">Object being reflected</param>
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
/// <returns></returns>
static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited)
{
// TODO cache these (and properties) by type, so that we only have to reflect once.
var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0);
var fields = o.GetType().GetFields(bindingFlags);
foreach (var field in fields)
{
var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute));
if (attr != null)
{
yield return (field, attr);
}
}
}
/// <summary>
/// Returns a PropertyInfo for all fields that have an ObservableAttribute
/// </summary>
/// <param name="o">Object being reflected</param>
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
/// <returns></returns>
static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited)
{
var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0);
var properties = o.GetType().GetProperties(bindingFlags);
foreach (var prop in properties)
{
var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute));
if (attr != null)
{
yield return (prop, attr);
}
}
}
/// <summary>
/// Creates sensors for each field and property with ObservableAttribute.
/// </summary>
/// <param name="o">Object being reflected</param>
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
/// <returns></returns>
internal static List<ISensor> CreateObservableSensors(object o, bool excludeInherited)
{
var sensorsOut = new List<ISensor>();
foreach (var(field, attr) in GetObservableFields(o, excludeInherited))
{
var sensor = CreateReflectionSensor(o, field, null, attr);
if (sensor != null)
{
sensorsOut.Add(sensor);
}
}
foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited))
{
if (!prop.CanRead)
{
// Skip unreadable properties.
continue;
}
var sensor = CreateReflectionSensor(o, null, prop, attr);
if (sensor != null)
{
sensorsOut.Add(sensor);
}
}
return sensorsOut;
}
/// <summary>
/// Create the ISensor for either the field or property on the provided object.
/// If the data type is unsupported, or the property is write-only, returns null.
/// </summary>
/// <param name="o"></param>
/// <param name="fieldInfo"></param>
/// <param name="propertyInfo"></param>
/// <param name="observableAttribute"></param>
/// <returns></returns>
/// <exception cref="UnityAgentsException"></exception>
static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute)
{
string memberName;
string declaringTypeName;
Type memberType;
if (fieldInfo != null)
{
declaringTypeName = fieldInfo.DeclaringType.Name;
memberName = fieldInfo.Name;
memberType = fieldInfo.FieldType;
}
else
{
declaringTypeName = propertyInfo.DeclaringType.Name;
memberName = propertyInfo.Name;
memberType = propertyInfo.PropertyType;
}
if (!s_TypeToSensorInfo.ContainsKey(memberType))
{
// For unsupported types, return null and we'll filter them out later.
return null;
}
string sensorName;
if (string.IsNullOrEmpty(observableAttribute.m_Name))
{
sensorName = $"ObservableAttribute:{declaringTypeName}.{memberName}";
}
else
{
sensorName = observableAttribute.m_Name;
}
var reflectionSensorInfo = new ReflectionSensorInfo
{
Object = o,
FieldInfo = fieldInfo,
PropertyInfo = propertyInfo,
ObservableAttribute = observableAttribute,
SensorName = sensorName
};
var (_, sensorType) = s_TypeToSensorInfo[memberType];
var sensor = (ISensor) Activator.CreateInstance(sensorType, reflectionSensorInfo);
// Wrap the base sensor in a StackingSensor if we're using stacking.
if (observableAttribute.m_NumStackedObservations > 1)
{
return new StackingSensor(sensor, observableAttribute.m_NumStackedObservations);
}
return sensor;
}
/// <summary>
/// Gets the sum of the observation sizes of the Observable fields and properties on an object.
/// Also appends errors to the errorsOut array.
/// </summary>
/// <param name="o"></param>
/// <param name="excludeInherited"></param>
/// <param name="errorsOut"></param>
/// <returns></returns>
internal static int GetTotalObservationSize(object o, bool excludeInherited, List<string> errorsOut)
{
int sizeOut = 0;
foreach (var(field, attr) in GetObservableFields(o, excludeInherited))
{
if (s_TypeToSensorInfo.ContainsKey(field.FieldType))
{
var (obsSize, _) = s_TypeToSensorInfo[field.FieldType];
sizeOut += obsSize * attr.m_NumStackedObservations;
}
else
{
errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}");
}
}
foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited))
{
if (s_TypeToSensorInfo.ContainsKey(prop.PropertyType))
{
if (prop.CanRead)
{
var (obsSize, _) = s_TypeToSensorInfo[prop.PropertyType];
sizeOut += obsSize * attr.m_NumStackedObservations;
}
else
{
errorsOut.Add($"Observable property {prop.Name} is write-only.");
}
}
else
{
errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}");
}
}
return sizeOut;
}
}
}

22
com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a quaternion field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class QuaternionReflectionSensor : ReflectionSensorBase
{
public QuaternionReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 4)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var quatVal = (UnityEngine.Quaternion)GetReflectedValue();
writer[0] = quatVal.x;
writer[1] = quatVal.y;
writer[2] = quatVal.z;
writer[3] = quatVal.w;
}
}
}

20
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a Vector2 field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class Vector2ReflectionSensor : ReflectionSensorBase
{
public Vector2ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 2)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var vecVal = (UnityEngine.Vector2)GetReflectedValue();
writer[0] = vecVal.x;
writer[1] = vecVal.y;
}
}
}

21
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a Vector3 field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class Vector3ReflectionSensor : ReflectionSensorBase
{
public Vector3ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 3)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var vecVal = (UnityEngine.Vector3)GetReflectedValue();
writer[0] = vecVal.x;
writer[1] = vecVal.y;
writer[2] = vecVal.z;
}
}
}

22
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a Vector4 field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class Vector4ReflectionSensor : ReflectionSensorBase
{
public Vector4ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 4)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var vecVal = (UnityEngine.Vector4)GetReflectedValue();
writer[0] = vecVal.x;
writer[1] = vecVal.y;
writer[2] = vecVal.z;
writer[3] = vecVal.w;
}
}
}
正在加载...
取消
保存