
Merge 'master' into develop-removeactionholder

/develop/removeactionholder-onehot
Ervin Teng, 4 years ago
Commit d57124b4
39 files changed, with 1114 insertions and 659 deletions
  1. com.unity.ml-agents/CHANGELOG.md (6 changes)
  2. com.unity.ml-agents/Editor/DemonstrationImporter.cs (2 changes)
  3. com.unity.ml-agents/Runtime/Academy.cs (29 changes)
  4. com.unity.ml-agents/Runtime/Agent.cs (135 changes)
  5. com.unity.ml-agents/Runtime/InferenceBrain/TensorProxy.cs (2 changes)
  6. com.unity.ml-agents/Runtime/Sensor/ISensor.cs (4 changes)
  7. com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs (534 changes)
  8. com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs (4 changes)
  9. com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs (4 changes)
  10. com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs (38 changes)
  11. com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs (6 changes)
  12. com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs (50 changes)
  13. config/gail_config.yaml (2 changes)
  14. docs/API-Reference.md (9 changes)
  15. docs/Migrating.md (4 changes)
  16. docs/Training-Imitation-Learning.md (2 changes)
  17. docs/dox-ml-agents.conf (133 changes)
  18. gym-unity/gym_unity/envs/__init__.py (145 changes)
  19. gym-unity/gym_unity/tests/test_gym.py (77 changes)
  20. ml-agents-envs/mlagents_envs/environment.py (62 changes)
  21. ml-agents/mlagents/trainers/common/nn_policy.py (82 changes)
  22. ml-agents/mlagents/trainers/common/tf_optimizer.py (6 changes)
  23. ml-agents/mlagents/trainers/learn.py (14 changes)
  24. ml-agents/mlagents/trainers/sac/trainer.py (2 changes)
  25. ml-agents/mlagents/trainers/tests/test_learn.py (13 changes)
  26. com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs.meta (2 changes)
  27. com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs.meta (2 changes)
  28. com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs.meta (14 changes)
  29. com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs (4 changes)
  30. com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs (104 changes)
  31. com.unity.ml-agents/Runtime/Demonstrations.meta (8 changes)
  32. com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs (179 changes)
  33. com.unity.ml-agents/Runtime/DemonstrationRecorder.cs (95 changes)
  34. /com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs.meta (0 changes)
  35. /com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs.meta (0 changes)
  36. /com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs.meta (0 changes)
  37. /com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs (0 changes)
  38. /com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs (0 changes)

com.unity.ml-agents/CHANGELOG.md (6 changes)


- Agent.CollectObservations now takes a VectorSensor argument. It was also overloaded to optionally take an ActionMasker argument. (#3352, #3389)
- Beta support for ONNX export was added. If the `tf2onnx` python package is installed, models will be saved to `.onnx` as well as `.nn` format.
Note that Barracuda 0.6.0 or later is required to import the `.onnx` files properly.
- Multi-GPU training and the `--multi-gpu` option has been removed temporarily. (#3345)
### Minor Changes
- Monitor.cs was moved to Examples. (#3372)

- A tutorial on adding custom SideChannels was added (#3391)
- The stepping logic for the Agent and the Academy has been simplified (#3448)
- Update Barracuda to 0.6.0-preview
- The interface for `RayPerceptionSensor.PerceiveStatic()` was changed to take an input class and write to an output class.
- The command-line argument used to determine the port that an environment will listen on was changed from `--port` to `--mlagents-port`.
- `DemonstrationRecorder` can now record observations outside of the editor.
- `DemonstrationRecorder` now has an optional path for the demonstrations. This will default to `Application.dataPath` if not set.
- `DemonstrationStore` was changed to accept a `Stream` for its constructor, and was renamed to `DemonstrationWriter`
- The method `GetStepCount()` on the Agent class has been replaced with the property getter `StepCount`
### Bugfixes

com.unity.ml-agents/Editor/DemonstrationImporter.cs (2 changes)


var metaDataProto = DemonstrationMetaProto.Parser.ParseDelimitedFrom(reader);
var metaData = metaDataProto.ToDemonstrationMetaData();
reader.Seek(DemonstrationStore.MetaDataBytes + 1, 0);
reader.Seek(DemonstrationWriter.MetaDataBytes + 1, 0);
var brainParamsProto = BrainParametersProto.Parser.ParseDelimitedFrom(reader);
var brainParameters = brainParamsProto.ToBrainParameters();

com.unity.ml-agents/Runtime/Academy.cs (29 changes)


{
const string k_ApiVersion = "API-15-dev0";
const int k_EditorTrainingPort = 5004;
internal const string k_portCommandLineFlag = "--mlagents-port";
/// <summary>
/// True if the Academy is initialized, false otherwise.
/// </summary>
/// <summary>
/// The singleton Academy object.
/// </summary>
/// <summary>
/// Collection of float properties (indexed by a string).
/// </summary>
public IFloatProperties FloatProperties;

// Signals to all the listeners that the academy is being destroyed
internal event Action DestroyAction;
// Signals the Agent that a new step is about to start.
// This will mark the Agent as Done if it has reached its maxSteps.
internal event Action AgentIncrementStep;

// Signals to all the agents each time the Academy force resets.
internal event Action AgentForceReset;
// Signals that the Academy has been reset by the training process
/// <summary>
/// Signals that the Academy has been reset by the training process.
/// </summary>
public event Action OnEnvironmentReset;
AcademyFixedUpdateStepper m_FixedUpdateStepper;

/// <summary>
/// Initialize the Academy if it hasn't already been initialized.
/// This method is always safe to call; it will have no effect if the Academy is already
/// initialized.
/// </summary>
internal void LazyInitialize()
{

}
/// <summary>
/// Enable stepping of the Academy during the FixedUpdate phase. This is done by creating
/// a temporary GameObject with a MonoBehaviour that calls Academy.EnvironmentStep().
/// </summary>
void EnableAutomaticStepping()
{

/// Registers SideChannel to the Academy to send and receive data with Python.
/// If IsCommunicatorOn is false, the SideChannel will not be registered.
/// </summary>
/// <param name="sideChannel"> The side channel to be registered.</param>
/// <param name="channel"> The side channel to be registered.</param>
public void RegisterSideChannel(SideChannel channel)
{
LazyInitialize();

/// Unregisters SideChannel to the Academy. If the side channel was not registered,
/// nothing will happen.
/// </summary>
/// <param name="sideChannel"> The side channel to be unregistered.</param>
/// <param name="channel"> The side channel to be unregistered.</param>
public void UnregisterSideChannel(SideChannel channel)
{
Communicator?.UnregisterSideChannel(channel);

var inputPort = "";
for (var i = 0; i < args.Length; i++)
{
if (args[i] == "--port")
if (args[i] == k_portCommandLineFlag)
{
inputPort = args[i + 1];
}
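The loop above scans the command-line arguments for the renamed `--mlagents-port` flag and takes the following argument as the port. For reference, the same scan can be sketched in Python; the function name is illustrative, and the default of 5004 is taken from the `k_EditorTrainingPort` constant shown earlier, not from the actual trainer code:

```python
def parse_port(args, flag="--mlagents-port", default=5004):
    """Scan an argument list for the port flag, mirroring the C# loop above."""
    for i, arg in enumerate(args):
        # The port value is expected immediately after the flag.
        if arg == flag and i + 1 < len(args):
            return int(args[i + 1])
    return default
```

For example, `parse_port(["env", "--mlagents-port", "6005"])` returns 6005, and the default is returned when the flag is absent.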

com.unity.ml-agents/Runtime/Agent.cs (135 changes)


using System.Collections.Generic;
using UnityEngine;
using Barracuda;
using UnityEngine.Serialization;
/// observations, actions and current status, that is sent to the Brain.
/// observations, actions and current status.
public struct AgentInfo
internal struct AgentInfo
{
/// <summary>
/// Keeps track of the last vector action taken by the Brain.

public float[] vectorActions;
}
/// Agent Monobehavior class that is attached to a Unity GameObject, making it
/// Agent MonoBehaviour class that is attached to a Unity GameObject, making it
/// user in <see cref="CollectObservations"/>. On the other hand, actions
/// are determined by decisions produced by a Policy. Currently, this
/// class is expected to be extended to implement the desired agent behavior.
/// user in <see cref="Agent.CollectObservations(VectorSensor)"/> or
/// <see cref="Agent.CollectObservations(VectorSensor, ActionMasker)"/>.
/// On the other hand, actions are determined by decisions produced by a Policy.
/// Currently, this class is expected to be extended to implement the desired agent behavior.
/// </summary>
/// <remarks>
/// Simply speaking, an agent roams through an environment and at each step

/// little may have changed between successive steps.
///
/// At any step, an agent may be considered <see cref="m_Done"/>.
/// This could occur due to a variety of reasons:
/// At any step, an agent may be considered done due to a variety of reasons:
/// - The agent reached an end state within its environment.
/// - The agent reached the maximum # of steps (i.e. timed out).
/// - The academy reached the maximum # of steps (forced agent to be done).

BehaviorParameters m_PolicyFactory;
/// This code is here to make the upgrade path for users using maxStep
/// easier. We will hook into the Serialization code and make sure that
/// agentParameters.maxStep and this.maxStep are in sync.
[Serializable]
internal struct AgentParameters

/// Whether or not the agent requests a decision.
bool m_RequestDecision;
/// Keeps track of the number of steps taken by the agent in this episode.
/// Note that this value is different for each agent, and may not overlap
/// with the step counter in the Academy, since agents reset based on

ActionMasker m_ActionMasker;
/// <summary>
/// Demonstration recorder.
/// Set of DemonstrationWriters that the Agent will write its step information to.
/// If you use a DemonstrationRecorder component, this will automatically register its DemonstrationWriter.
/// You can also add your own DemonstrationWriter by calling
/// DemonstrationRecorder.AddDemonstrationWriterToAgent()
DemonstrationRecorder m_Recorder;
internal ISet<DemonstrationWriter> DemonstrationWriters = new HashSet<DemonstrationWriter>();
/// <summary>
/// List of sensors used to generate observations.

/// </summary>
internal VectorSensor collectObservationsSensor;
/// MonoBehaviour function that is called when the attached GameObject
/// becomes enabled or active.
/// <summary>
/// <inheritdoc cref="OnBeforeSerialize"/>
/// </summary>
// Manages a serialization upgrade issue from v0.13 to v0.14 where maxStep moved
// from AgentParameters (since removed) to Agent
if (maxStep == 0 && maxStep != agentParameters.maxStep && !hasUpgradedFromAgentParameters)
{
maxStep = agentParameters.maxStep;

/// <summary>
/// <inheritdoc cref="OnAfterDeserialize"/>
/// </summary>
// Manages a serialization upgrade issue from v0.13 to v0.14 where maxStep moved
// from AgentParameters (since removed) to Agent
if (maxStep == 0 && maxStep != agentParameters.maxStep && !hasUpgradedFromAgentParameters)
{
maxStep = agentParameters.maxStep;

/// Helper method for the <see cref="OnEnable"/> event, created to
/// facilitate testing.
/// <summary>
/// Initializes the agent. Can be safely called multiple times.
/// </summary>
public void LazyInitialize()
{
if (m_Initialized)

// Grab the "static" properties for the Agent.
m_EpisodeId = EpisodeIdCounter.GetEpisodeId();
m_PolicyFactory = GetComponent<BehaviorParameters>();
m_Recorder = GetComponent<DemonstrationRecorder>();
m_Info = new AgentInfo();
m_Action = new AgentAction();

InitializeSensors();
}
/// MonoBehaviour function that is called when the attached GameObject
/// becomes disabled or inactive.
DemonstrationWriters.Clear();
// If Academy.Dispose has already been called, we don't need to unregister with it.
// We don't want to even try, because this will lazily create a new Academy!
if (Academy.IsInitialized)

// We request a decision so Python knows the Agent is done immediately
m_Brain?.RequestDecision(m_Info, sensors);
if (m_Recorder != null && m_Recorder.record && Application.isEditor)
// We also have to write to any DemonstrationWriters so that they get the "done" flag.
foreach(var demoWriter in DemonstrationWriters)
m_Recorder.WriteExperience(m_Info, sensors);
demoWriter.Record(m_Info, sensors);
}
UpdateRewardStats();

m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic);
}
/// <summary>
/// Returns the current step counter (within the current episode).
/// </summary>
/// <returns>

/// </returns>
public virtual float[] Heuristic()
{
throw new UnityAgentsException(string.Format(
throw new UnityAgentsException(
"{0} GameObject.",
gameObject.name));
$"{gameObject.name} GameObject.");
}
/// <summary>

collectObservationsSensor = new VectorSensor(param.vectorObservationSize);
if (param.numStackedVectorObservations > 1)
{
var stackingSensor = new StackingSensor(collectObservationsSensor, param.numStackedVectorObservations);
var stackingSensor = new StackingSensor(
collectObservationsSensor, param.numStackedVectorObservations);
sensors.Add(stackingSensor);
}
else

// Make sure the names are actually unique
for (var i = 0; i < sensors.Count - 1; i++)
{
Debug.Assert(!sensors[i].GetName().Equals(sensors[i + 1].GetName()), "Sensor names must be unique.");
Debug.Assert(
!sensors[i].GetName().Equals(sensors[i + 1].GetName()),
"Sensor names must be unique.");
}
#endif
}

m_Brain.RequestDecision(m_Info, sensors);
if (m_Recorder != null && m_Recorder.record && Application.isEditor)
// If we have any DemonstrationWriters, write the AgentInfo and sensors to them.
foreach(var demoWriter in DemonstrationWriters)
m_Recorder.WriteExperience(m_Info, sensors);
demoWriter.Record(m_Info, sensors);
for (var i = 0; i < sensors.Count; i++)
foreach (var sensor in sensors)
sensors[i].Update();
sensor.Update();
/// <summary>
/// Collects the vector observations of the agent.

/// <param name="sensor">
/// The vector observations for the agent.
/// </param>
/// <remarks>
/// An agent's observation is any environment information that helps
/// the Agent achieve its goal. For example, for a fighting Agent, its

/// Vector observations are added by calling the provided helper methods
/// on the VectorSensor input:
/// - <see cref="AddObservation(int)"/>
/// - <see cref="AddObservation(float)"/>
/// - <see cref="AddObservation(Vector3)"/>
/// - <see cref="AddObservation(Vector2)"/>
/// - <see cref="AddObservation(Quaternion)"/>
/// - <see cref="AddObservation(bool)"/>
/// - <see cref="AddOneHotObservation(int, int)"/>
/// - <see cref="VectorSensor.AddObservation(int)"/>
/// - <see cref="VectorSensor.AddObservation(float)"/>
/// - <see cref="VectorSensor.AddObservation(Vector3)"/>
/// - <see cref="VectorSensor.AddObservation(Vector2)"/>
/// - <see cref="VectorSensor.AddObservation(Quaternion)"/>
/// - <see cref="VectorSensor.AddObservation(bool)"/>
/// - <see cref="VectorSensor.AddObservation(IEnumerable{float})"/>
/// - <see cref="VectorSensor.AddOneHotObservation(int, int)"/>
/// Depending on your environment, any combination of these helpers can
/// be used. They just need to be used in the exact same order each time
/// this method is called and the resulting size of the vector observation

}
/// <summary>
/// Collects the vector observations of the agent.
/// Collects the vector observations of the agent alongside the masked actions.
/// <param name="sensor">
/// The vector observations for the agent.
/// </param>
/// <param name="actionMasker">
/// The masked actions for the agent.
/// </param>
/// <remarks>
/// An agent's observation is any environment information that helps
/// the Agent achieve its goal. For example, for a fighting Agent, its

/// Vector observations are added by calling the provided helper methods
/// on the VectorSensor input:
/// - <see cref="AddObservation(int)"/>
/// - <see cref="AddObservation(float)"/>
/// - <see cref="AddObservation(Vector3)"/>
/// - <see cref="AddObservation(Vector2)"/>
/// - <see cref="AddObservation(Quaternion)"/>
/// - <see cref="AddObservation(bool)"/>
/// - <see cref="AddOneHotObservation(int, int)"/>
/// - <see cref="VectorSensor.AddObservation(int)"/>
/// - <see cref="VectorSensor.AddObservation(float)"/>
/// - <see cref="VectorSensor.AddObservation(Vector3)"/>
/// - <see cref="VectorSensor.AddObservation(Vector2)"/>
/// - <see cref="VectorSensor.AddObservation(Quaternion)"/>
/// - <see cref="VectorSensor.AddObservation(bool)"/>
/// - <see cref="VectorSensor.AddObservation(IEnumerable{float})"/>
/// - <see cref="VectorSensor.AddOneHotObservation(int, int)"/>
/// Depending on your environment, any combination of these helpers can
/// be used. They just need to be used in the exact same order each time
/// this method is called and the resulting size of the vector observation

/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it. You can call the following method on the ActionMasker
/// input:
/// - <see cref="SetActionMask(int branch, IEnumerable<int> actionIndices)"/>
/// - <see cref="SetActionMask(int branch, int actionIndex)"/>
/// - <see cref="SetActionMask(IEnumerable<int> actionIndices)"/>
/// - <see cref="SetActionMask(int branch, int actionIndex)"/>
/// - <see cref="ActionMasker.SetActionMask(int)"/>
/// - <see cref="ActionMasker.SetActionMask(int, int)"/>
/// - <see cref="ActionMasker.SetActionMask(int, IEnumerable{int})"/>
/// - <see cref="ActionMasker.SetActionMask(IEnumerable{int})"/>
/// The branch input is the index of the action, actionIndices are the indices of the
/// invalid options for that action.
/// </remarks>

}
/// <summary>
/// Returns the last action that was decided on by the Agent (returns null if no decision has been made)
/// Returns the last action that was decided on by the Agent
/// <returns>
/// The last action that was decided by the Agent (or null if no decision has been made)
/// </returns>
public float[] GetAction()
{
return m_Action.vectorActions;

/// <param name="min"></param>
/// <param name="max"></param>
/// <returns></returns>
protected float ScaleAction(float rawAction, float min, float max)
protected static float ScaleAction(float rawAction, float min, float max)
{
var middle = (min + max) / 2;
var range = (max - min) / 2;
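`ScaleAction` (now static) maps a raw action in [-1, 1] onto [min, max] using the midpoint and half-range computed above. The same arithmetic in a small Python sketch; the final multiply-add is assumed from the midpoint/range setup, since the diff truncates before the return statement:

```python
def scale_action(raw_action: float, min_val: float, max_val: float) -> float:
    """Map raw_action in [-1, 1] onto [min_val, max_val]."""
    middle = (min_val + max_val) / 2      # center of the target interval
    half_range = (max_val - min_val) / 2  # half-width of the target interval
    return raw_action * half_range + middle
```

So a raw action of -1 maps to `min_val`, 0 to the midpoint, and 1 to `max_val`.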

com.unity.ml-agents/Runtime/InferenceBrain/TensorProxy.cs (2 changes)


/// allowing the user to specify everything but the data in a graphical way.
/// </summary>
[Serializable]
public class TensorProxy
internal class TensorProxy
{
public enum TensorType
{

com.unity.ml-agents/Runtime/Sensor/ISensor.cs (4 changes)


/// Note that this (and GetCompressedObservation) may be called multiple times per agent step, so should not
/// mutate any internal state.
/// </summary>
/// <param name="adapater"></param>
/// <param name="adapter"></param>
int Write(WriteAdapter adapater);
int Write(WriteAdapter adapter);
/// <summary>
/// Return a compressed representation of the observation. For small observations, this should generally not be

com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs (534 changes)


namespace MLAgents
{
public class RayPerceptionSensor : ISensor
/// <summary>
/// Determines which dimensions the sensor will perform the casts in.
/// </summary>
public enum RayPerceptionCastType
{
/// Cast in 2 dimensions, using Physics2D.CircleCast or Physics2D.RayCast.
Cast2D,
/// Cast in 3 dimensions, using Physics.SphereCast or Physics.RayCast.
Cast3D,
}
public struct RayPerceptionInput
public enum CastType
{
Cast2D,
Cast3D,
}
/// <summary>
/// Length of the rays to cast. This will be scaled up or down based on the scale of the transform.
/// </summary>
public float rayLength;
/// <summary>
/// List of tags which correspond to object types agent can see.
/// </summary>
public IReadOnlyList<string> detectableTags;
/// <summary>
/// List of angles (in degrees) used to define the rays.
/// 90 degrees is considered "forward" relative to the game object.
/// </summary>
public IReadOnlyList<float> angles;
/// <summary>
/// Starting height offset of ray from center of agent
/// </summary>
public float startOffset;
float[] m_Observations;
int[] m_Shape;
string m_Name;
/// <summary>
/// Ending height offset of ray from center of agent.
/// </summary>
public float endOffset;
float m_RayDistance;
List<string> m_DetectableObjects;
float[] m_Angles;
/// <summary>
/// Radius of the sphere to use for spherecasting.
/// If 0 or less, rays are used instead - this may be faster, especially for complex environments.
/// </summary>
public float castRadius;
float m_StartOffset;
float m_EndOffset;
float m_CastRadius;
CastType m_CastType;
Transform m_Transform;
int m_LayerMask;
/// <summary>
/// Transform of the GameObject.
/// </summary>
public Transform transform;
/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
/// Whether to perform the casts in 2D or 3D.
public class DebugDisplayInfo
public RayPerceptionCastType castType;
/// <summary>
/// Filtering options for the casts.
/// </summary>
public int layerMask;
/// <summary>
/// Returns the expected number of floats in the output.
/// </summary>
/// <returns></returns>
public int OutputSize()
public struct RayInfo
return (detectableTags.Count + 2) * angles.Count;
}
/// <summary>
/// Get the cast start and end points for the given ray index.
/// </summary>
/// <param name="rayIndex"></param>
/// <returns>A tuple of the start and end positions in world space.</returns>
public (Vector3 StartPositionWorld, Vector3 EndPositionWorld) RayExtents(int rayIndex)
{
var angle = angles[rayIndex];
Vector3 startPositionLocal, endPositionLocal;
if (castType == RayPerceptionCastType.Cast3D)
public Vector3 localStart;
public Vector3 localEnd;
public Vector3 worldStart;
public Vector3 worldEnd;
public bool castHit;
public float hitFraction;
public float castRadius;
startPositionLocal = new Vector3(0, startOffset, 0);
endPositionLocal = PolarToCartesian3D(rayLength, angle);
endPositionLocal.y += endOffset;
public void Reset()
else
m_Frame = Time.frameCount;
// Vector2s here get converted to Vector3s (and back to Vector2s for casting)
startPositionLocal = new Vector2();
endPositionLocal = PolarToCartesian2D(rayLength, angle);
var startPositionWorld = transform.TransformPoint(startPositionLocal);
var endPositionWorld = transform.TransformPoint(endPositionLocal);
return (StartPositionWorld: startPositionWorld, EndPositionWorld: endPositionWorld);
}
/// <summary>
/// Converts polar coordinate to cartesian coordinate.
/// </summary>
static internal Vector3 PolarToCartesian3D(float radius, float angleDegrees)
{
var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
var z = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
return new Vector3(x, 0f, z);
}
/// <summary>
/// Converts polar coordinate to cartesian coordinate.
/// </summary>
static internal Vector2 PolarToCartesian2D(float radius, float angleDegrees)
{
var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
var y = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
return new Vector2(x, y);
}
}
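The two polar-to-cartesian helpers place 90 degrees at "forward" (+z in 3D, +y in 2D). A Python sketch of the same conversion, for illustration only:

```python
import math

def polar_to_cartesian_3d(radius, angle_degrees):
    # 90 degrees maps to +z, the sensor's "forward" direction; y stays 0.
    x = radius * math.cos(math.radians(angle_degrees))
    z = radius * math.sin(math.radians(angle_degrees))
    return (x, 0.0, z)

def polar_to_cartesian_2d(radius, angle_degrees):
    # Same convention in 2D: 90 degrees maps to +y.
    x = radius * math.cos(math.radians(angle_degrees))
    y = radius * math.sin(math.radians(angle_degrees))
    return (x, y)
```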
public class RayPerceptionOutput
{
public struct RayOutput
{
/// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
/// Whether or not the ray hit anything.
/// </summary>
public bool hasHit;
/// <summary>
/// Whether or not the ray hit an object whose tag is in the input's detectableTags list.
/// </summary>
public bool hitTaggedObject;
/// <summary>
/// The index of the hit object's tag in the detectableTags list, or -1 if there was no hit, or the
/// hit object has a different tag.
/// </summary>
public int hitTagIndex;
/// <summary>
/// Normalized distance to the hit object.
/// </summary>
public float hitFraction;
/// <summary>
/// Writes the ray output information to a subset of the float array. Each element in the rayAngles array
/// determines a sublist of data to the observation. The sublist contains the observation data for a single cast.
/// The list is composed of the following:
/// 1. A one-hot encoding for detectable tags. For example, if detectableTags.Length = n, the
/// first n elements of the sublist will be a one-hot encoding of the detectableTag that was hit, or
/// all zeroes otherwise.
/// 2. The 'numDetectableTags' element of the sublist will be 1 if the ray missed everything, or 0 if it hit
/// something (detectable or not).
/// 3. The 'numDetectableTags+1' element of the sublist will contain the normalized distance to the object
/// hit, or 1.0 if nothing was hit.
public int age
/// <param name="numDetectableTags"></param>
/// <param name="rayIndex"></param>
/// <param name="buffer">Output buffer. The size must be equal to (numDetectableTags+2) * rayOutputs.Length</param>
public void ToFloatArray(int numDetectableTags, int rayIndex, float[] buffer)
get { return Time.frameCount - m_Frame; }
var bufferOffset = (numDetectableTags + 2) * rayIndex;
if (hitTaggedObject)
{
buffer[bufferOffset + hitTagIndex] = 1f;
}
buffer[bufferOffset + numDetectableTags] = hasHit ? 0f : 1f;
buffer[bufferOffset + numDetectableTags + 1] = hitFraction;
}
/// <summary>
/// RayOutput for each ray that was cast.
/// </summary>
public RayOutput[] rayOutputs;
}
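Per the `ToFloatArray` doc comment above, each ray contributes `numDetectableTags + 2` floats: a one-hot slot per detectable tag, a miss flag, and the normalized hit fraction, so the full observation is `(numDetectableTags + 2) * numRays` long, matching `OutputSize()`. A hedged Python sketch of that buffer layout (an approximation of `ToFloatArray`, not the C# source; the hit/tag-index guard is simplified):

```python
def ray_to_floats(num_tags, ray_index, hit, hit_tag_index, hit_fraction, buffer):
    """Write one ray's sublist into the shared observation buffer.

    Layout per ray: [one-hot tag flags..., miss flag, normalized hit fraction].
    """
    offset = (num_tags + 2) * ray_index
    if hit and 0 <= hit_tag_index < num_tags:
        buffer[offset + hit_tag_index] = 1.0      # one-hot for the detected tag
    buffer[offset + num_tags] = 0.0 if hit else 1.0  # 1 means the ray missed everything
    buffer[offset + num_tags + 1] = hit_fraction     # 1.0 when nothing was hit
```

For two detectable tags and one ray, a hit on tag 1 at half distance yields `[0, 1, 0, 0.5]`, while a miss yields `[0, 0, 1, 1.0]`.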
public RayInfo[] rayInfos;
/// <summary>
/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
/// </summary>
internal class DebugDisplayInfo
{
public struct RayInfo
{
public Vector3 worldStart;
public Vector3 worldEnd;
public float castRadius;
public RayPerceptionOutput.RayOutput rayOutput;
}
public void Reset()
{
m_Frame = Time.frameCount;
}
int m_Frame;
/// <summary>
/// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
/// </summary>
public int age
{
get { return Time.frameCount - m_Frame; }
public RayInfo[] rayInfos;
int m_Frame;
}
public class RayPerceptionSensor : ISensor
{
float[] m_Observations;
int[] m_Shape;
string m_Name;
RayPerceptionInput m_RayPerceptionInput;
public DebugDisplayInfo debugDisplayInfo
internal DebugDisplayInfo debugDisplayInfo
public RayPerceptionSensor(string name, float rayDistance, List<string> detectableObjects, float[] angles,
Transform transform, float startOffset, float endOffset, float castRadius, CastType castType,
int rayLayerMask)
public RayPerceptionSensor(string name, RayPerceptionInput rayInput)
var numObservations = (detectableObjects.Count + 2) * angles.Length;
var numObservations = rayInput.OutputSize();
m_RayPerceptionInput = rayInput;
m_RayDistance = rayDistance;
m_DetectableObjects = detectableObjects;
// TODO - preprocess angles, save ray directions instead?
m_Angles = angles;
m_Transform = transform;
m_StartOffset = startOffset;
m_EndOffset = endOffset;
m_CastRadius = castRadius;
m_CastType = castType;
m_LayerMask = rayLayerMask;
if (Application.isEditor)
{

{
using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
{
PerceiveStatic(
m_RayDistance, m_Angles, m_DetectableObjects, m_StartOffset, m_EndOffset,
m_CastRadius, m_Transform, m_CastType, m_Observations, m_LayerMask,
m_DebugDisplayInfo
);
Array.Clear(m_Observations, 0, m_Observations.Length);
var numRays = m_RayPerceptionInput.angles.Count;
var numDetectableTags = m_RayPerceptionInput.detectableTags.Count;
if (m_DebugDisplayInfo != null)
{
// Reset the age information, and resize the buffer if needed.
m_DebugDisplayInfo.Reset();
if (m_DebugDisplayInfo.rayInfos == null || m_DebugDisplayInfo.rayInfos.Length != numRays)
{
m_DebugDisplayInfo.rayInfos = new DebugDisplayInfo.RayInfo[numRays];
}
}
// For each ray, do the casting, and write the information to the observation buffer
for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
{
DebugDisplayInfo.RayInfo debugRay;
var rayOutput = PerceiveSingleRay(m_RayPerceptionInput, rayIndex, out debugRay);
if (m_DebugDisplayInfo != null)
{
m_DebugDisplayInfo.rayInfos[rayIndex] = debugRay;
}
rayOutput.ToFloatArray(numDetectableTags, rayIndex, m_Observations);
}
// Finally, add the observations to the WriteAdapter
adapter.AddRange(m_Observations);
}
return m_Observations.Length;

}
/// <summary>
/// Evaluates a perception vector to be used as part of an observation of an agent.
/// Each element in the rayAngles array determines a sublist of data to the observation.
/// The sublist contains the observation data for a single cast. The list is composed of the following:
/// 1. A one-hot encoding for detectable objects. For example, if detectableObjects.Length = n, the
/// first n elements of the sublist will be a one-hot encoding of the detectableObject that was hit, or
/// all zeroes otherwise.
/// 2. The 'length' element of the sublist will be 1 if the ray missed everything, or 0 if it hit
/// something (detectable or not).
/// 3. The 'length+1' element of the sublist will contain the normalised distance to the object hit, or 1 if
/// nothing was hit.
///
/// Evaluates the raycasts to be used as part of an observation of an agent.
/// <param name="unscaledRayLength"></param>
/// <param name="rayAngles">List of angles (in degrees) used to define the rays. 90 degrees is considered
/// "forward" relative to the game object</param>
/// <param name="detectableObjects">List of tags which correspond to object types agent can see</param>
/// <param name="startOffset">Starting height offset of ray from center of agent.</param>
/// <param name="endOffset">Ending height offset of ray from center of agent.</param>
/// <param name="unscaledCastRadius">Radius of the sphere to use for spherecasting. If 0 or less, rays are used
/// instead - this may be faster, especially for complex environments.</param>
/// <param name="transform">Transform of the GameObject</param>
/// <param name="castType">Whether to perform the casts in 2D or 3D.</param>
/// <param name="perceptionBuffer">Output array of floats. Must be (num rays) * (num tags + 2) in size.</param>
/// <param name="layerMask">Filtering options for the casts</param>
/// <param name="debugInfo">Optional debug information output, only used by RayPerceptionSensor.</param>
/// <param name="input">Input defining the rays that will be cast.</param>
/// <returns>Output struct containing the raycast results.</returns>
public static RayPerceptionOutput PerceiveStatic(RayPerceptionInput input)
{
    RayPerceptionOutput output = new RayPerceptionOutput();
    output.rayOutputs = new RayPerceptionOutput.RayOutput[input.angles.Count];

    for (var rayIndex = 0; rayIndex < input.angles.Count; rayIndex++)
    {
        DebugDisplayInfo.RayInfo debugRay;
        output.rayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex, out debugRay);
    }

    return output;
}

/// <summary>
/// Evaluate the raycast results of a single ray from the RayPerceptionInput.
/// </summary>
/// <param name="input">Input defining the rays that will be cast.</param>
/// <param name="rayIndex">Index of the ray to cast.</param>
/// <param name="debugRayOut">Debug information for the ray.</param>
/// <returns>Output of the single ray cast.</returns>
static RayPerceptionOutput.RayOutput PerceiveSingleRay(
    RayPerceptionInput input,
    int rayIndex,
    out DebugDisplayInfo.RayInfo debugRayOut
)
{
    var unscaledRayLength = input.rayLength;
    var unscaledCastRadius = input.castRadius;

    var extents = input.RayExtents(rayIndex);
    var startPositionWorld = extents.StartPositionWorld;
    var endPositionWorld = extents.EndPositionWorld;

    var rayDirection = endPositionWorld - startPositionWorld;
    // If there is non-unity scale, |rayDirection| will be different from rayLength.
    // We want to use this transformed ray length for determining cast length, hit fraction etc.
    // We also use it to scale up or down the sphere or circle radii.
    var scaledRayLength = rayDirection.magnitude;
    // Avoid 0/0 if unscaledRayLength is 0.
    var scaledCastRadius = unscaledRayLength > 0 ?
        unscaledCastRadius * scaledRayLength / unscaledRayLength :
        unscaledCastRadius;

    // Do the cast and assign the hit information for each detectable tag.
    bool castHit;
    float hitFraction;
    GameObject hitObject;

    if (input.castType == RayPerceptionCastType.Cast3D)
    {
        RaycastHit rayHit;
        if (scaledCastRadius > 0f)
        {
            castHit = Physics.SphereCast(startPositionWorld, scaledCastRadius, rayDirection, out rayHit,
                scaledRayLength, input.layerMask);
        }
        else
        {
            castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit,
                scaledRayLength, input.layerMask);
        }

        // If scaledRayLength is 0, we still could have a hit with sphere casts (maybe?).
        // To avoid 0/0, set the fraction to 0.
        hitFraction = castHit ? (scaledRayLength > 0 ? rayHit.distance / scaledRayLength : 0.0f) : 1.0f;
        hitObject = castHit ? rayHit.collider.gameObject : null;
    }
    else
    {
        RaycastHit2D rayHit;
        if (scaledCastRadius > 0f)
        {
            rayHit = Physics2D.CircleCast(startPositionWorld, scaledCastRadius, rayDirection,
                scaledRayLength, input.layerMask);
        }
        else
        {
            rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, scaledRayLength, input.layerMask);
        }

        castHit = rayHit;
        hitFraction = castHit ? rayHit.fraction : 1.0f;
        hitObject = castHit ? rayHit.collider.gameObject : null;
    }

    var rayOutput = new RayPerceptionOutput.RayOutput
    {
        hasHit = castHit,
        hitFraction = hitFraction,
        hitTaggedObject = false,
        hitTagIndex = -1
    };

    if (castHit)
    {
        // Find the index of the tag of the object that was hit.
        for (var i = 0; i < input.detectableTags.Count; i++)
        {
            if (hitObject.CompareTag(input.detectableTags[i]))
            {
                rayOutput.hitTaggedObject = true;
                rayOutput.hitTagIndex = i;
                break;
            }
        }
    }

    debugRayOut.worldStart = startPositionWorld;
    debugRayOut.worldEnd = endPositionWorld;
    debugRayOut.rayOutput = rayOutput;
    debugRayOut.castRadius = scaledCastRadius;

    return rayOutput;
}

/// <summary>
/// Converts polar coordinates to a cartesian coordinate in the x-z plane.
/// </summary>
static Vector3 PolarToCartesian3D(float radius, float angleDegrees)
{
    var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
    var z = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
    return new Vector3(x, 0f, z);
}

/// <summary>
/// Converts polar coordinates to a cartesian coordinate in the x-y plane.
/// </summary>
static Vector2 PolarToCartesian2D(float radius, float angleDegrees)
{
    var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
    var y = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
    return new Vector2(x, y);
}
}
}
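The refactored sensor returns structured `RayOutput`s instead of writing floats directly, while the legacy per-ray float layout (a one-hot over detectable tags, a "missed" flag, then the hit fraction) remains recoverable via `RayPerceptionOutput.ToFloatArray()`. As a rough illustration of that layout only, here is a Python sketch; the function name and argument list are ours, not part of the API:

```python
def ray_output_to_floats(hit_tag_index, has_hit, hit_fraction, num_tags):
    # Per-ray sublist layout, matching the buffer comments in the sensor code:
    #   [0 .. num_tags-1]  one-hot over detectable tags for the object hit
    #   [num_tags]         1.0 if nothing was hit, else 0.0
    #   [num_tags + 1]     hit fraction along the ray (1.0 if no hit)
    buf = [0.0] * (num_tags + 2)
    if has_hit and hit_tag_index >= 0:
        buf[hit_tag_index] = 1.0
    if not has_hit:
        buf[num_tags] = 1.0
    buf[num_tags + 1] = hit_fraction if has_hit else 1.0
    return buf
```

Note that a hit on an object whose tag is not in the detectable list (`hit_tag_index == -1` with `has_hit` true) still records the hit fraction but sets no one-hot entry, mirroring the original buffer-writing code.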

com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs


rayLayerMask = Physics2D.DefaultRaycastLayers;
}
public override RayPerceptionCastType GetCastType()
{
    return RayPerceptionCastType.Cast2D;
}
}
}

com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs


[Tooltip("Ray end is offset up or down by this amount.")]
public float endVerticalOffset;
public override RayPerceptionCastType GetCastType()
{
    return RayPerceptionCastType.Cast3D;
}
public override float GetStartVerticalOffset()

com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs


[Header("Debug Gizmos", order = 999)]
public Color rayHitColor = Color.red;
public Color rayMissColor = Color.white;
[Tooltip("Whether to draw the raycasts in the world space where they happened, or using the Agent's current transform.")]
public bool useWorldPositions = true;
public abstract RayPerceptionCastType GetCastType();
public virtual float GetStartVerticalOffset()
{

public override ISensor CreateSensor()
{
var rayAngles = GetRayAngles(raysPerDirection, maxRayDegrees);

var rayPerceptionInput = new RayPerceptionInput();
rayPerceptionInput.rayLength = rayLength;
rayPerceptionInput.detectableTags = detectableTags;
rayPerceptionInput.angles = rayAngles;
rayPerceptionInput.startOffset = GetStartVerticalOffset();
rayPerceptionInput.endOffset = GetEndVerticalOffset();
rayPerceptionInput.castRadius = sphereCastRadius;
rayPerceptionInput.transform = transform;
rayPerceptionInput.castType = GetCastType();
rayPerceptionInput.layerMask = rayLayerMask;

m_RaySensor = new RayPerceptionSensor(sensorName, rayPerceptionInput);
if (observationStacks != 1)
{

public override int[] GetObservationShape()
{
var numRays = 2 * raysPerDirection + 1;
var numTags = detectableTags?.Count ?? 0;
var obsSize = (numTags + 2) * numRays;
var stacks = observationStacks > 1 ? observationStacks : 1;
return new[] { obsSize * stacks };
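The observation size computed above is simple arithmetic: one forward ray plus `raysPerDirection` on each side, and `numTags + 2` floats per ray (one-hot, miss flag, hit fraction), multiplied by the stack count. A quick sketch (the function name is ours, for illustration only):

```python
def observation_size(rays_per_direction, num_tags, observation_stacks=1):
    # One ray straight ahead plus rays_per_direction on each side.
    num_rays = 2 * rays_per_direction + 1
    # Each ray contributes a one-hot over tags, a miss flag, and a hit fraction.
    obs_size = (num_tags + 2) * num_rays
    stacks = observation_stacks if observation_stacks > 1 else 1
    return obs_size * stacks
```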

foreach (var rayInfo in debugInfo.rayInfos)
{
// Either use the original world-space coordinates of the raycast, or transform the agent-local
// coordinates of the rays to the current transform of the agent. If the agent acts every frame,
// these should be the same.
if (!useWorldPositions)
{
startPositionWorld = transform.TransformPoint(rayInfo.localStart);
endPositionWorld = transform.TransformPoint(rayInfo.localEnd);
}
rayDirection *= rayInfo.rayOutput.hitFraction;
var lerpT = rayInfo.rayOutput.hitFraction * rayInfo.rayOutput.hitFraction;
var color = Color.Lerp(rayHitColor, rayMissColor, lerpT);
color.a *= alpha;
Gizmos.color = color;

if (rayInfo.rayOutput.hasHit)
{
var hitRadius = Mathf.Max(rayInfo.castRadius, .05f);
Gizmos.DrawWireSphere(startPositionWorld + rayDirection, hitRadius);

com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs


TensorShape m_TensorShape;
internal WriteAdapter() { }
/// <summary>
/// Set the adapter to write to an IList at the given channelOffset.
/// </summary>

internal void SetTarget(IList<float> data, int[] shape, int offset)
{
m_Data = data;
m_Offset = offset;

/// <param name="tensorProxy">Tensor proxy that will be written to.</param>
/// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent)</param>
/// <param name="channelOffset">Offset from the start of the channel to write to.</param>
internal void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
{
m_Proxy = tensorProxy;
m_Batch = batchIndex;

com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs


[TestFixture]
public class DemonstrationTests : MonoBehaviour
{
const string k_DemoDirectory = "Assets/Demonstrations/";
const string k_ExtensionType = ".demo";
const string k_DemoName = "Test";

public void TestStoreInitalize()
{
    var fileSystem = new MockFileSystem();

    var gameobj = new GameObject("gameObj");

    var bp = gameobj.AddComponent<BehaviorParameters>();
    bp.brainParameters.vectorObservationSize = 3;
    bp.brainParameters.numStackedVectorObservations = 2;
    bp.brainParameters.vectorActionDescriptions = new[] { "TestActionA", "TestActionB" };
    bp.brainParameters.vectorActionSize = new[] { 2, 2 };
    bp.brainParameters.vectorActionSpaceType = SpaceType.Discrete;

    var agent = gameobj.AddComponent<TestAgent>();

    Assert.IsFalse(fileSystem.Directory.Exists(k_DemoDirectory));

    var demoRec = gameobj.AddComponent<DemonstrationRecorder>();
    demoRec.record = true;
    demoRec.demonstrationName = k_DemoName;
    demoRec.demonstrationDirectory = k_DemoDirectory;
    var demoWriter = demoRec.LazyInitialize(fileSystem);

    Assert.IsTrue(fileSystem.Directory.Exists(k_DemoDirectory));
    Assert.IsTrue(fileSystem.FileExists(k_DemoDirectory + k_DemoName + k_ExtensionType));

    var agentInfo = new AgentInfo
    {
        storedVectorActions = new[] { 0f, 1f },
    };

    demoWriter.Record(agentInfo, new System.Collections.Generic.List<ISensor>());
    demoRec.Close();

    // Make sure close can be called multiple times
    demoWriter.Close();
    demoRec.Close();

    // Make sure trying to write after closing doesn't raise an error.
    demoWriter.Record(agentInfo, new System.Collections.Generic.List<ISensor>());
}
public class ObservationAgent : TestAgent

agentGo1.AddComponent<DemonstrationRecorder>();
var demoRecorder = agentGo1.GetComponent<DemonstrationRecorder>();
var fileSystem = new MockFileSystem();
demoRecorder.demonstrationDirectory = k_DemoDirectory;
demoRecorder.LazyInitialize(fileSystem);
var agentEnableMethod = typeof(Agent).GetMethod("OnEnable",
BindingFlags.Instance | BindingFlags.NonPublic);

// Read back the demo file and make sure observations were written
var reader = fileSystem.File.OpenRead("Assets/Demonstrations/TestBrain.demo");
reader.Seek(DemonstrationWriter.MetaDataBytes + 1, 0);
BrainParametersProto.Parser.ParseDelimitedFrom(reader);
var agentInfoProto = AgentInfoActionPairProto.Parser.ParseDelimitedFrom(reader).AgentInfo;

config/gail_config.yaml


strength: 1.0
gamma: 0.99
encoding_size: 128
demo_path: Project/Assets/ML-Agents/Examples/PushBlock/Demos/ExpertPush.demo
Hallway:
use_recurrent: true
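For context, `demo_path` sits inside a GAIL reward-signal block alongside the other keys shown in this hunk. A sketch of the surrounding structure; the behavior name and the `reward_signals`/`gail` nesting are our assumption about the typical layout of `gail_config.yaml`, not taken verbatim from this diff:

```yaml
PushBlock:
  reward_signals:
    gail:
      strength: 1.0
      gamma: 0.99
      encoding_size: 128
      demo_path: Project/Assets/ML-Agents/Examples/PushBlock/Demos/ExpertPush.demo
```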

docs/API-Reference.md


# API Reference
Our developer-facing C# classes have been documented to be compatible with
Doxygen for auto-generating HTML documentation.
To generate the API reference, download Doxygen
and run the following command within the `docs/` directory:

subdirectory to navigate to the API reference home. Note that `html/` is already
included in the repository's `.gitignore` file.
In the near future, we aim to expand our documentation to include the Python
classes.

docs/Migrating.md


### Important changes
* The `Agent.CollectObservations()` virtual method now takes as input a `VectorSensor` sensor as argument. The `Agent.AddVectorObs()` methods were removed.
* The `SetActionMask` method must now be called on the optional `ActionMasker` argument of the `CollectObservations` method. (We now consider an action mask as a type of observation)
* The interface for `RayPerceptionSensor.PerceiveStatic()` was changed to take an input class and write to an output class.
* The `--multi-gpu` option has been removed temporarily.
* If you call `RayPerceptionSensor.PerceiveStatic()` manually, add your inputs to a `RayPerceptionInput`. To get the previous float array output, use `RayPerceptionOutput.ToFloatArray()`.
* Re-import all of your `*.NN` files to work with the updated Barracuda package.
* Replace all calls to `Agent.GetStepCount()` with `Agent.StepCount`

docs/Training-Imitation-Learning.md


from a few minutes or a few hours of demonstration data may be necessary to
be useful for imitation learning. When you have recorded enough data, end
the Editor play session, and a `.demo` file will be created in the
`Assets/Demonstrations` folder (by default). This file contains the demonstrations.
Clicking on the file will provide metadata about the demonstration in the
inspector.

docs/dox-ml-agents.conf


# Doxyfile 1.8.13
# To generate the C# API documentation, run:
#
# doxygen dox-ml-agents.conf
#
# from the ml-agents-docs directory

# title of most generated pages and in a few other places.
# The default value is: My Project.
PROJECT_NAME = "Unity ML-Agents Toolkit"
PROJECT_NUMBER =
PROJECT_BRIEF =
# With the PROJECT_LOGO tag one can specify a logo or an icon that is included
# in the documentation. The maximum height of the logo should not exceed 55

# entered, it will be relative to the location where doxygen was started. If
# left blank the current directory will be used.
OUTPUT_DIRECTORY =
# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub-
# directories (in 2 levels) under the output directory of each output format and

# will be relative from the directory where doxygen is started.
# This tag requires that the tag FULL_PATH_NAMES is set to YES.
STRIP_FROM_PATH =
# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the
# path mentioned in the documentation of a class, which tells the reader which

# using the -I flag.
STRIP_FROM_INC_PATH =
# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but
# less readable) file names. This can be useful is your file systems doesn't

# "Side Effects:". You can put \n's in the value part of an alias to insert
# newlines.
ALIASES =

TCL_SUBST =
# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources
# only. Doxygen will then generate output that is more tailored for C. For

# Note that for custom extensions you also need to set FILE_PATTERNS otherwise
# the files are not read by doxygen.
EXTENSION_MAPPING =
# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments
# according to the Markdown format, which allows for more readable

# sections, marked by \if <section_label> ... \endif and \cond <section_label>
# ... \endcond blocks.
ENABLED_SECTIONS =
# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the
# initial value of a variable or macro / define can have for it to appear in the

# by doxygen. Whatever the program writes to standard output is used as the file
# version. For an example see the documentation.
FILE_VERSION_FILTER =
# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed
# by doxygen. The layout file controls the global structure of the generated

# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the
# search path. See also \cite for info how to create references.
CITE_BIB_FILES =
#---------------------------------------------------------------------------
# Configuration options related to warning and progress messages

# messages should be written. If left blank the output is written to standard
# error (stderr).
WARN_LOGFILE =
#---------------------------------------------------------------------------
# Configuration options related to the input files

# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
# Note: If this tag is empty the current directory is searched.
INPUT = ../com.unity.ml-agents/Runtime/
# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses

# Note that relative paths are relative to the directory from which doxygen is
# run.
EXCLUDE =
# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or
# directories that are symbolic links (a Unix file system feature) are excluded

# Note that the wildcards are matched against the file with absolute path, so to
# exclude all test directories for example use the pattern */test/*
EXCLUDE_PATTERNS =
# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names
# (namespaces, classes, functions, etc.) that should be excluded from the

# Note that the wildcards are matched against the file with absolute path, so to
# exclude all test directories use the pattern */test/*
EXCLUDE_SYMBOLS =

EXAMPLE_PATH =
# If the value of the EXAMPLE_PATH tag contains directories, you can use the
# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and

EXAMPLE_PATTERNS =
# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be
# searched for input files to be used with the \include or \dontinclude commands

# need to set EXTENSION_MAPPING for the extension otherwise the files are not
# properly processed by doxygen.
INPUT_FILTER =
# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern
# basis. Doxygen will compare the file name with each pattern and apply the

# need to set EXTENSION_MAPPING for the extension otherwise the files are not
# properly processed by doxygen.
FILTER_PATTERNS =
# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
# INPUT_FILTER) will also be used to filter the input files that are used for

# *.ext= (so without naming a filter).
# This tag requires that the tag FILTER_SOURCE_FILES is set to YES.
FILTER_SOURCE_PATTERNS =
# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that
# is part of the input, its contents will be placed on the main page

# generated with the -Duse-libclang=ON option for CMake.
# The default value is: NO.
#CLANG_ASSISTED_PARSING = NO
# If clang assisted parsing is enabled you can provide the compiler with command
# line options that you would normally use when invoking the compiler. Note that

#CLANG_OPTIONS =
#---------------------------------------------------------------------------
# Configuration options related to the alphabetical class index

# while generating the index headers.
# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
IGNORE_PREFIX =
#---------------------------------------------------------------------------
# Configuration options related to the HTML output

# files will be copied as-is; there are no commands or markers available.
# This tag requires that the tag GENERATE_HTML is set to YES.
HTML_EXTRA_FILES =
# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen
# will adjust the colors in the style sheet and background images according to

# written to the html output directory.
# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
CHM_FILE =
# The HHC_LOCATION tag can be used to specify the location (absolute path
# including file name) of the HTML help compiler (hhc.exe). If non-empty,

HHC_LOCATION =
# The GENERATE_CHI flag controls if a separate .chi index file is generated
# (YES) or that it should be included in the master .chm file (NO).

# and project file content.
# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
CHM_INDEX_ENCODING =
# The BINARY_TOC flag controls whether a binary table of contents is generated
# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it

# the HTML output folder.
# This tag requires that the tag GENERATE_QHP is set to YES.
QCH_FILE =
# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help
# Project output. For more information please see Qt Help Project / Namespace

# filters).
# This tag requires that the tag GENERATE_QHP is set to YES.
QHP_CUST_FILTER_NAME =
# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the
# custom filter to add. For more information please see Qt Help Project / Custom

QHP_CUST_FILTER_ATTRS =
# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this
# project's filter section matches. Qt Help Project / Filter Attributes (see:

QHP_SECT_FILTER_ATTRS =
# The QHG_LOCATION tag can be used to specify the location of Qt's
# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the

QHG_LOCATION =
# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be
# generated, together with the HTML files, they form an Eclipse help plugin. To

# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
# This tag requires that the tag USE_MATHJAX is set to YES.
MATHJAX_EXTENSIONS =
# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
# of code that will be used on startup of the MathJax code. See the MathJax site

MATHJAX_CODEFILE =
# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
# the HTML output. The underlying search engine uses javascript and DHTML and

# Searching" for details.
# This tag requires that the tag SEARCHENGINE is set to YES.
SEARCHENGINE_URL =
# When SERVER_BASED_SEARCH and EXTERNAL_SEARCH are both enabled the unindexed
# search data is written to a file for indexing by an external tool. With the

# projects and redirect the results back to the right project.
# This tag requires that the tag SEARCHENGINE is set to YES.
EXTERNAL_SEARCH_ID =
# The EXTRA_SEARCH_MAPPINGS tag can be used to enable searching through doxygen
# projects other than the one defined by this configuration file, but that are

# EXTRA_SEARCH_MAPPINGS = tagname1=loc1 tagname2=loc2 ...
# This tag requires that the tag SEARCHENGINE is set to YES.
EXTRA_SEARCH_MAPPINGS =
#---------------------------------------------------------------------------
# Configuration options related to the LaTeX output

# If left blank no extra packages will be included.
# This tag requires that the tag GENERATE_LATEX is set to YES.
EXTRA_PACKAGES =
# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the
# generated LaTeX document. The header should contain everything until the first

# to HTML_HEADER.
# This tag requires that the tag GENERATE_LATEX is set to YES.
LATEX_HEADER =
# The LATEX_FOOTER tag can be used to specify a personal LaTeX footer for the
# generated LaTeX document. The footer should contain everything after the last

# Note: Only use a user-defined footer if you know what you are doing!
# This tag requires that the tag GENERATE_LATEX is set to YES.
LATEX_FOOTER =
# The LATEX_EXTRA_STYLESHEET tag can be used to specify additional user-defined
# LaTeX style sheets that are included after the standard style sheets created

# list).
# This tag requires that the tag GENERATE_LATEX is set to YES.
LATEX_EXTRA_STYLESHEET =
# The LATEX_EXTRA_FILES tag can be used to specify one or more extra images or
# other source files which should be copied to the LATEX_OUTPUT output

LATEX_EXTRA_FILES =
# If the PDF_HYPERLINKS tag is set to YES, the LaTeX that is generated is
# prepared for conversion to PDF (using ps2pdf or pdflatex). The PDF file will

# default style sheet that doxygen normally uses.
# This tag requires that the tag GENERATE_RTF is set to YES.
RTF_STYLESHEET_FILE =
# Set optional variables used in the generation of an RTF document. Syntax is
# similar to doxygen's config file. A template extensions file can be generated

RTF_EXTENSIONS_FILE =
# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code
# with syntax highlighting in the RTF output.

# MAN_EXTENSION with the initial . removed.
# This tag requires that the tag GENERATE_MAN is set to YES.
MAN_SUBDIR =
# If the MAN_LINKS tag is set to YES and doxygen generates man output, then it
# will generate one additional man file for each entity documented in the real

# overwrite each other's variables.
# This tag requires that the tag GENERATE_PERLMOD is set to YES.
PERLMOD_MAKEVAR_PREFIX =
#---------------------------------------------------------------------------
# Configuration options related to the preprocessor

# preprocessor.
# This tag requires that the tag SEARCH_INCLUDES is set to YES.
INCLUDE_PATH =
# You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard
# patterns (like *.h and *.hpp) to filter out the header-files in the

INCLUDE_FILE_PATTERNS =
# The PREDEFINED tag can be used to specify one or more macro names that are
# defined before the preprocessor is started (similar to the -D option of e.g.

# definition found in the source code.
# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
EXPAND_AS_DEFINED =
# If the SKIP_FUNCTION_MACROS tag is set to YES then doxygen's preprocessor will
# remove all references to function-like macros that are alone on a line, have

# the path). If a tag file is not located in the directory in which doxygen is
# run, you must also specify the path to the tagfile here.
TAGFILES =

GENERATE_TAGFILE =
# If the ALLEXTERNALS tag is set to YES, all external class will be listed in
# the class index. If set to NO, only the inherited external classes will be

# interpreter (i.e. the result of 'which perl').
# The default file (with absolute path) is: /usr/bin/perl.
#PERL_PATH = /usr/bin/perl
#---------------------------------------------------------------------------
# Configuration options related to the dot tool

# the mscgen tool resides. If left empty the tool is assumed to be found in the
# default search path.
#MSCGEN_PATH =
# You can include diagrams made with dia in doxygen documentation. Doxygen will
# then run dia to produce the diagram and insert it in the documentation. The

DIA_PATH =
# If set to YES the inheritance and collaboration graphs will hide inheritance
# and usage relations if the target is undocumented or is not a class.

# the path where dot can find it using this tag.
# This tag requires that the tag HAVE_DOT is set to YES.
DOT_FONTPATH =
# If the CLASS_GRAPH tag is set to YES then doxygen will generate a graph for
# each documented class showing the direct and indirect inheritance relations.

# found. If left blank, it is assumed the dot tool can be found in the path.
# This tag requires that the tag HAVE_DOT is set to YES.
DOT_PATH =
# The DOTFILE_DIRS tag can be used to specify one or more directories that
# contain dot files that are included in the documentation (see the \dotfile

DOTFILE_DIRS =

MSCFILE_DIRS =

DIAFILE_DIRS =
# When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the
# path where java can find the plantuml.jar file. If left blank, it is assumed

PLANTUML_JAR_PATH =

PLANTUML_CFG_FILE =

PLANTUML_INCLUDE_PATH =
# The DOT_GRAPH_MAX_NODES tag can be used to set the maximum number of nodes
# that will be shown in the graph. If the number of nodes in a graph becomes

gym-unity/gym_unity/envs/__init__.py


import logging
import itertools
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
from gym import error, spaces

self.visual_obs = None
self._n_agents = -1
self._done_agents: Set[int] = set()
self.agent_mapper = AgentIdIndexMapper()
# Save the step result from the last time all Agents requested decisions.
self._previous_step_result: BatchedStepResult = None
self._multiagent = multiagent

step_result = self._env.get_step_result(self.brain_name)
self._check_agents(step_result.n_agents())
self._previous_step_result = step_result
self.agent_mapper.set_initial_agents(list(self._previous_step_result.agent_id))
# Set observation and action spaces
if self.group_spec.is_action_discrete():

"The number of agents in the scene does not match the expected number."
)
# remove the done Agents
indices_to_keep: List[int] = []
for index, is_done in enumerate(step_result.done):
if not is_done:
indices_to_keep.append(index)
if step_result.n_agents() - sum(step_result.done) != self._n_agents:
raise UnityGymException(
"The number of agents in the scene does not match the expected number."
)
for index, agent_id in enumerate(step_result.agent_id):
if step_result.done[index]:
self.agent_mapper.mark_agent_done(agent_id, step_result.reward[index])
# Set the new AgentDone flags to True
# Note that the corresponding agent_id that gets marked done will be different

if not self._previous_step_result.contains_agent(agent_id):
step_result.done[index] = True
if agent_id in self._done_agents:
# Register this agent, and get the reward of the previous agent that
# was in its index, so that we can return it to the gym.
last_reward = self.agent_mapper.register_new_agent_id(agent_id)
self._done_agents = set()
step_result.reward[index] = last_reward
# Get a permutation of the agent IDs so that a given ID stays in the same
# index as where it was first seen.
new_id_order = self.agent_mapper.get_id_permutation(list(step_result.agent_id))
_mask.append(step_result.action_mask[mask_index][indices_to_keep])
_mask.append(step_result.action_mask[mask_index][new_id_order])
new_obs.append(step_result.obs[obs_index][indices_to_keep])
new_obs.append(step_result.obs[obs_index][new_id_order])
reward=step_result.reward[indices_to_keep],
done=step_result.done[indices_to_keep],
max_step=step_result.max_step[indices_to_keep],
agent_id=step_result.agent_id[indices_to_keep],
reward=step_result.reward[new_id_order],
done=step_result.done[new_id_order],
max_step=step_result.max_step[new_id_order],
agent_id=step_result.agent_id[new_id_order],
if self._previous_step_result.n_agents() == self._n_agents:
return action
input_index = 0
for index in range(self._previous_step_result.n_agents()):
for index, agent_id in enumerate(self._previous_step_result.agent_id):
sanitized_action[index, :] = action[input_index, :]
input_index = input_index + 1
array_index = self.agent_mapper.get_gym_index(agent_id)
sanitized_action[index, :] = action[array_index, :]
return sanitized_action
def _step(self, needs_reset: bool = False) -> BatchedStepResult:

"The environment does not have the expected number of agents."
+ "Some agents did not request decisions at the same time."
)
self._done_agents.update(list(info.agent_id))
for agent_id, reward in zip(info.agent_id, info.reward):
self.agent_mapper.mark_agent_done(agent_id, reward)
self._env.step()
info = self._env.get_step_result(self.brain_name)
return self._sanitize_info(info)

:return: The List containing the branched actions.
"""
return self.action_lookup[action]
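The `action_lookup` above maps a single flat discrete action back to one action per branch. A minimal sketch of how such a lookup can be built with `itertools.product` (the `make_action_lookup` name is hypothetical, not from the source):

```python
import itertools

def make_action_lookup(branch_sizes):
    """Map a flat discrete action index to one sub-action per branch."""
    branched = itertools.product(*(range(size) for size in branch_sizes))
    return {i: list(combo) for i, combo in enumerate(branched)}

# For branch sizes [2, 2, 3] there are 2 * 2 * 3 = 12 flat actions.
lookup = make_action_lookup([2, 2, 3])
assert len(lookup) == 12
assert lookup[0] == [0, 0, 0]
assert lookup[11] == [1, 1, 2]
```

The last branch varies fastest, so consecutive flat indices differ in the final branch first.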
class AgentIdIndexMapper:
def __init__(self) -> None:
self._agent_id_to_gym_index: Dict[int, int] = {}
self._done_agents_index_to_last_reward: Dict[int, float] = {}
def set_initial_agents(self, agent_ids: List[int]) -> None:
"""
Provide the initial list of agent ids for the mapper
"""
for idx, agent_id in enumerate(agent_ids):
self._agent_id_to_gym_index[agent_id] = idx
def mark_agent_done(self, agent_id: int, reward: float) -> None:
"""
Declare the agent done with the corresponding final reward.
"""
gym_index = self._agent_id_to_gym_index.pop(agent_id)
self._done_agents_index_to_last_reward[gym_index] = reward
def register_new_agent_id(self, agent_id: int) -> float:
"""
Adds the new agent ID and returns the reward to use for the previous agent in this index
"""
# Any free index is OK here.
free_index, last_reward = self._done_agents_index_to_last_reward.popitem()
self._agent_id_to_gym_index[agent_id] = free_index
return last_reward
def get_id_permutation(self, agent_ids: List[int]) -> List[int]:
"""
Get the permutation from new agent ids to the order that preserves the positions of previous agents.
The result is a list with each integer from 0 to len(agent_ids)-1 appearing exactly once.
"""
# Map the new agent ids to their index
new_agent_ids_to_index = {
agent_id: idx for idx, agent_id in enumerate(agent_ids)
}
# Make the output list. We don't write to it sequentially, so start with dummy values.
new_permutation = [-1] * len(agent_ids)
# For each agent ID, find the new index of the agent, and write it in the original index.
for agent_id, original_index in self._agent_id_to_gym_index.items():
new_permutation[original_index] = new_agent_ids_to_index[agent_id]
return new_permutation
def get_gym_index(self, agent_id: int) -> int:
"""
Get the gym index for the current agent.
"""
return self._agent_id_to_gym_index[agent_id]
class AgentIdIndexMapperSlow:
"""
Reference implementation of AgentIdIndexMapper.
The operations are O(N^2) so it shouldn't be used for large numbers of agents.
See AgentIdIndexMapper for method descriptions
"""
def __init__(self) -> None:
self._gym_id_order: List[int] = []
self._done_agents_index_to_last_reward: Dict[int, float] = {}
def set_initial_agents(self, agent_ids: List[int]) -> None:
self._gym_id_order = list(agent_ids)
def mark_agent_done(self, agent_id: int, reward: float) -> None:
gym_index = self._gym_id_order.index(agent_id)
self._done_agents_index_to_last_reward[gym_index] = reward
self._gym_id_order[gym_index] = -1
def register_new_agent_id(self, agent_id: int) -> float:
original_index = self._gym_id_order.index(-1)
self._gym_id_order[original_index] = agent_id
reward = self._done_agents_index_to_last_reward.pop(original_index)
return reward
def get_id_permutation(self, agent_ids):
new_id_order = []
for agent_id in self._gym_id_order:
new_id_order.append(agent_ids.index(agent_id))
return new_id_order
def get_gym_index(self, agent_id: int) -> int:
return self._gym_id_order.index(agent_id)
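Both mapper classes implement the same bookkeeping contract: a done agent frees its gym index and leaves behind its final reward, and the next newly registered agent claims a free index and inherits that reward. A minimal dict-based sketch of the contract (the `IdIndexMapper` name here is hypothetical):

```python
class IdIndexMapper:
    """Minimal sketch of the agent-id -> gym-index bookkeeping."""

    def __init__(self):
        self._id_to_index = {}
        self._free_index_to_reward = {}

    def set_initial_agents(self, agent_ids):
        self._id_to_index = {a: i for i, a in enumerate(agent_ids)}

    def mark_agent_done(self, agent_id, reward):
        # Free the agent's index and remember its final reward.
        self._free_index_to_reward[self._id_to_index.pop(agent_id)] = reward

    def register_new_agent_id(self, agent_id):
        # Claim any free index; return the reward left by its previous owner.
        index, reward = self._free_index_to_reward.popitem()
        self._id_to_index[agent_id] = index
        return reward

mapper = IdIndexMapper()
mapper.set_initial_agents([10, 11, 12])
mapper.mark_agent_done(11, 5.0)                 # index 1 becomes free
assert mapper.register_new_agent_id(99) == 5.0  # new agent reuses index 1
```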

77
gym-unity/gym_unity/tests/test_gym.py


import numpy as np
from gym import spaces
from gym_unity.envs import UnityEnv, UnityGymException
from gym_unity.envs import (
UnityEnv,
UnityGymException,
AgentIdIndexMapper,
AgentIdIndexMapperSlow,
)
from mlagents_envs.base_env import AgentGroupSpec, ActionType, BatchedStepResult

assert isinstance(info, dict)
@mock.patch("gym_unity.envs.UnityEnvironment")
def test_sanitize_action_shuffled_id(mock_env):
mock_spec = create_mock_group_spec(
vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
)
mock_step = create_mock_vector_step_result(num_agents=5)
mock_step.agent_id = np.array(range(5))
setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
env = UnityEnv(" ", use_visual=False, multiagent=True)
shuffled_step_result = create_mock_vector_step_result(num_agents=5)
shuffled_order = [4, 2, 3, 1, 0]
shuffled_step_result.reward = np.array(shuffled_order)
shuffled_step_result.agent_id = np.array(shuffled_order)
sanitized_result = env._sanitize_info(shuffled_step_result)
for expected_reward, reward in zip(range(5), sanitized_result.reward):
assert expected_reward == reward
for expected_agent_id, agent_id in zip(range(5), sanitized_result.agent_id):
assert expected_agent_id == agent_id
@mock.patch("gym_unity.envs.UnityEnvironment")
def test_sanitize_action_one_agent_done(mock_env):
mock_spec = create_mock_group_spec(
vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
)
mock_step = create_mock_vector_step_result(num_agents=5)
mock_step.agent_id = np.array(range(5))
setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
env = UnityEnv(" ", use_visual=False, multiagent=True)
received_step_result = create_mock_vector_step_result(num_agents=6)
received_step_result.agent_id = np.array(range(6))
# agent #3 (id = 2) is Done
received_step_result.done = np.array([False] * 2 + [True] + [False] * 3)
sanitized_result = env._sanitize_info(received_step_result)
for expected_agent_id, agent_id in zip([0, 1, 5, 3, 4], sanitized_result.agent_id):
assert expected_agent_id == agent_id
# Helper methods

mock_env.return_value.get_agent_groups.return_value = ["MockBrain"]
mock_env.return_value.get_agent_group_spec.return_value = mock_spec
mock_env.return_value.get_step_result.return_value = mock_result
@pytest.mark.parametrize("mapper_cls", [AgentIdIndexMapper, AgentIdIndexMapperSlow])
def test_agent_id_index_mapper(mapper_cls):
mapper = mapper_cls()
initial_agent_ids = [1001, 1002, 1003, 1004]
mapper.set_initial_agents(initial_agent_ids)
# Mark some agents as done with their last rewards.
mapper.mark_agent_done(1001, 42.0)
mapper.mark_agent_done(1004, 1337.0)
# Now add new agents, and get the rewards of the agent they replaced.
old_reward1 = mapper.register_new_agent_id(2001)
old_reward2 = mapper.register_new_agent_id(2002)
# The order of the rewards doesn't matter
assert {old_reward1, old_reward2} == {42.0, 1337.0}
new_agent_ids = [1002, 1003, 2001, 2002]
permutation = mapper.get_id_permutation(new_agent_ids)
# Make sure it's actually a permutation - needs to contain 0..N-1 with no repeats.
assert set(permutation) == set(range(0, 4))
# Initial agents that are still active must stay in their original slot.
# Agents that were added later can appear in any free slot.
permuted_ids = [new_agent_ids[i] for i in permutation]
for idx, agent_id in enumerate(initial_agent_ids):
if agent_id in permuted_ids:
assert permuted_ids[idx] == agent_id

62
ml-agents-envs/mlagents_envs/environment.py


SINGLE_BRAIN_ACTION_TYPES = SCALAR_ACTION_TYPES + (list, np.ndarray)
API_VERSION = "API-15-dev0"
DEFAULT_EDITOR_PORT = 5004
PORT_COMMAND_LINE_ARG = "--mlagents-port"
def __init__(
self,

def get_communicator(worker_id, base_port, timeout_wait):
return RpcCommunicator(worker_id, base_port, timeout_wait)
def executable_launcher(self, file_name, docker_training, no_graphics, args):
cwd = os.getcwd()
file_name = (
file_name.strip()
@staticmethod
def validate_environment_path(env_path: str) -> Optional[str]:
# Strip out executable extensions if passed
env_path = (
env_path.strip()
true_filename = os.path.basename(os.path.normpath(file_name))
true_filename = os.path.basename(os.path.normpath(env_path))
if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
return None
cwd = os.getcwd()
true_filename = os.path.basename(os.path.normpath(env_path))
candidates = glob.glob(os.path.join(cwd, file_name) + ".x86_64")
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86_64")
candidates = glob.glob(os.path.join(cwd, file_name) + ".x86")
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86")
candidates = glob.glob(file_name + ".x86_64")
candidates = glob.glob(env_path + ".x86_64")
candidates = glob.glob(file_name + ".x86")
candidates = glob.glob(env_path + ".x86")
os.path.join(
cwd, file_name + ".app", "Contents", "MacOS", true_filename
)
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename)
os.path.join(file_name + ".app", "Contents", "MacOS", true_filename)
os.path.join(env_path + ".app", "Contents", "MacOS", true_filename)
os.path.join(cwd, file_name + ".app", "Contents", "MacOS", "*")
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*")
os.path.join(file_name + ".app", "Contents", "MacOS", "*")
os.path.join(env_path + ".app", "Contents", "MacOS", "*")
candidates = glob.glob(os.path.join(cwd, file_name + ".exe"))
candidates = glob.glob(os.path.join(cwd, env_path + ".exe"))
candidates = glob.glob(file_name + ".exe")
candidates = glob.glob(env_path + ".exe")
return launch_string
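The refactor above renames `file_name` to `env_path` inside `validate_environment_path`, which strips known build extensions and then globs for platform-specific candidates. A simplified, self-contained sketch of that resolution logic (the `find_executable` name is hypothetical, and only a few of the real candidate patterns are shown):

```python
import glob
import os
import tempfile

def find_executable(env_path):
    # Strip known build extensions so "env" and "env.x86_64" are equivalent.
    for ext in (".app", ".exe", ".x86_64", ".x86"):
        env_path = env_path.replace(ext, "")
    # Glob for platform-specific candidates, most specific first.
    for pattern in (env_path + ".x86_64", env_path + ".x86", env_path + ".exe"):
        candidates = glob.glob(pattern)
        if candidates:
            return candidates[0]
    return None

# The extension on the input is optional: both forms resolve to the same file.
demo_dir = tempfile.mkdtemp()
build = os.path.join(demo_dir, "myenv.x86_64")
open(build, "w").close()
assert find_executable(os.path.join(demo_dir, "myenv")) == build
assert find_executable(os.path.join(demo_dir, "myenv.x86_64")) == build
```

Returning `None` for an unresolvable path is what lets the caller raise a clear "Couldn't launch" error instead of failing inside `subprocess`.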
def executable_launcher(self, file_name, docker_training, no_graphics, args):
launch_string = self.validate_environment_path(file_name)
"Couldn't launch the {0} environment. "
"Provided filename does not match any environments.".format(
true_filename
)
f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
)
else:
logger.debug("This is the launch string {}".format(launch_string))

if no_graphics:
subprocess_args += ["-nographics", "-batchmode"]
subprocess_args += ["--port", str(self.port)]
subprocess_args += [
UnityEnvironment.PORT_COMMAND_LINE_ARG,
str(self.port),
]
subprocess_args += args
try:
self.proc1 = subprocess.Popen(

# we created with `xvfb`.
#
docker_ls = (
"exec xvfb-run --auto-servernum"
" --server-args='-screen 0 640x480x24'"
" {0} --port {1}"
).format(launch_string, str(self.port))
f"exec xvfb-run --auto-servernum --server-args='-screen 0 640x480x24'"
f" {launch_string} {UnityEnvironment.PORT_COMMAND_LINE_ARG} {self.port}"
)
self.proc1 = subprocess.Popen(
docker_ls,
stdout=subprocess.PIPE,

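The docker branch above wraps the environment in `xvfb-run` and switches from `--port` to the named `PORT_COMMAND_LINE_ARG`. A sketch of how the resulting launch command is assembled (the `docker_launch_command` helper is hypothetical; the constant value matches the source):

```python
PORT_COMMAND_LINE_ARG = "--mlagents-port"

def docker_launch_command(launch_string, port):
    """Build the xvfb-run wrapper command used for headless docker training."""
    return (
        "exec xvfb-run --auto-servernum --server-args='-screen 0 640x480x24'"
        f" {launch_string} {PORT_COMMAND_LINE_ARG} {port}"
    )

cmd = docker_launch_command("./env.x86_64", 5004)
assert cmd.startswith("exec xvfb-run")
assert cmd.endswith("./env.x86_64 --mlagents-port 5004")
```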
82
ml-agents/mlagents/trainers/common/nn_policy.py


is_training: bool,
load: bool,
tanh_squash: bool = False,
resample: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
create_tf_graph: bool = True,
):

:param is_training: Whether the model should be trained.
:param load: Whether a pre-trained model will be loaded or a new one created.
:param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output.
:param resample: Whether we are using the resampling trick to update the policy in continuous output.
:param reparameterize: Whether we are using the resampling trick to update the policy in continuous output.
"""
super().__init__(seed, brain, trainer_params, load)
self.grads = None

trainer_params.get("vis_encode_type", "simple")
)
self.tanh_squash = tanh_squash
self.resample = resample
self.reparameterize = reparameterize
self.condition_sigma_on_obs = condition_sigma_on_obs
self.trainable_variables: List[tf.Variable] = []

return
self.create_input_placeholders()
encoded = self._create_encoder(
self.visual_in,
self.processed_vector_in,
self.h_size,
self.num_layers,
self.vis_encode_type,
)
self.h_size,
self.num_layers,
self.vis_encode_type,
encoded,
self.resample,
self.reparameterize,
self._create_dc_actor(
self.h_size, self.num_layers, self.vis_encode_type
)
self._create_dc_actor(encoded)
self.trainable_variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
)

run_out = self._execute_model(feed_dict, self.inference_dict)
return run_out
def _create_cc_actor(
def _create_encoder(
visual_in: List[tf.Tensor],
vector_in: tf.Tensor,
tanh_squash: bool = False,
resample: bool = False,
condition_sigma_on_obs: bool = True,
) -> None:
) -> tf.Tensor:
Creates Continuous control actor-critic model.
Creates an encoder for visual and vector observations.
:param tanh_squash: Whether to use a tanh function, or a clipped output.
:param resample: Whether we are using the resampling trick to update the policy.
:return: The hidden layer (tf.Tensor) after the encoder.
hidden_stream = ModelUtils.create_observation_streams(
encoded = ModelUtils.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,

)[0]
return encoded
def _create_cc_actor(
self,
encoded: tf.Tensor,
tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
) -> None:
"""
Creates Continuous control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
:param vis_encode_type: Type of visual encoder to use if visual input.
:param tanh_squash: Whether to use a tanh function, or a clipped output.
:param reparameterize: Whether we are using the resampling trick to update the policy.
"""
hidden_stream,
self.memory_in,
self.sequence_length_ph,
name="lstm_policy",
encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy"
hidden_policy = hidden_stream
hidden_policy = encoded
with tf.variable_scope("policy"):
mu = tf.layers.dense(

sampled_policy = mu + sigma * epsilon
# Stop gradient if we're not doing the resampling trick
if not resample:
if not reparameterize:
sampled_policy_probs = tf.stop_gradient(sampled_policy)
else:
sampled_policy_probs = sampled_policy

(tf.identity(self.all_log_probs)), axis=1, keepdims=True
)
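The `resample` to `reparameterize` rename above controls whether gradients flow through the sampled action. The idea behind the reparameterization trick: with the noise draw held fixed, the sample is a deterministic function of `mu` and `sigma`, so its derivatives are well defined. A library-free finite-difference sketch:

```python
def reparameterized_sample(mu, sigma, epsilon):
    # With the noise draw held fixed, the sample is a deterministic,
    # differentiable function of mu and sigma, so gradients can flow
    # through it. With reparameterize=False the code above instead wraps
    # the sample in stop_gradient, treating it as a constant.
    return mu + sigma * epsilon

eps = 0.5
h = 1e-6
dmu = (reparameterized_sample(2.0 + h, 1.0, eps)
       - reparameterized_sample(2.0, 1.0, eps)) / h
dsigma = (reparameterized_sample(2.0, 1.0 + h, eps)
          - reparameterized_sample(2.0, 1.0, eps)) / h
assert abs(dmu - 1.0) < 1e-4      # d(sample)/d(mu) = 1
assert abs(dsigma - eps) < 1e-4   # d(sample)/d(sigma) = epsilon
```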
def _create_dc_actor(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
def _create_dc_actor(self, encoded: tf.Tensor) -> None:
"""
Creates Discrete control actor-critic model.
:param h_size: Size of hidden linear layers.

with tf.variable_scope("policy"):
hidden_stream = ModelUtils.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
h_size,
num_layers,
vis_encode_type,
)[0]
if self.use_recurrent:
self.prev_action = tf.placeholder(
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"

],
axis=1,
)
hidden_policy = tf.concat([hidden_stream, prev_action_oh], axis=1)
hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"

self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
else:
hidden_policy = hidden_stream
hidden_policy = encoded
policy_branches = []
with tf.variable_scope("policy"):

6
ml-agents/mlagents/trainers/common/tf_optimizer.py


# We do this in a separate step to feed the memory outs - a further optimization would
# be to append to the obs before running sess.run.
final_value_estimates = self.get_value_estimates(
final_value_estimates = self._get_value_estimates(
def get_value_estimates(
def _get_value_estimates(
self,
next_obs: List[np.ndarray],
done: bool,

self.update_dict.update(self.reward_signals[reward_signal].update_dict)
def create_optimizer_op(
self, learning_rate: float, name: str = "Adam"
self, learning_rate: tf.Tensor, name: str = "Adam"
) -> tf.train.Optimizer:
return tf.train.AdamOptimizer(learning_rate=learning_rate, name=name)

14
ml-agents/mlagents/trainers/learn.py


from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.exception import UnityEnvironmentException
def _create_parser():

env_args: Optional[List[str]],
) -> Callable[[int, List[SideChannel]], BaseEnv]:
if env_path is not None:
# Strip out executable extensions if passed
env_path = (
env_path.strip()
.replace(".app", "")
.replace(".exe", "")
.replace(".x86_64", "")
.replace(".x86", "")
)
launch_string = UnityEnvironment.validate_environment_path(env_path)
if launch_string is None:
raise UnityEnvironmentException(
f"Couldn't launch the {env_path} environment. Provided filename does not match any environments."
)
docker_training = docker_target_name is not None
if docker_training and env_path is not None:
# Comments for future maintenance:

2
ml-agents/mlagents/trainers/sac/trainer.py


self.is_training,
self.load,
tanh_squash=True,
resample=True,
reparameterize=True,
create_tf_graph=False,
)
for _reward_signal in policy.reward_signals.keys():

13
ml-agents/mlagents/trainers/tests/test_learn.py


from mlagents.trainers import learn
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.learn import parse_command_line
from mlagents_envs.exception import UnityEnvironmentException
def basic_options(extra_args=None):

mock_init.assert_called_once()
assert mock_init.call_args[0][1] == "/dockertarget/models/ppo"
assert mock_init.call_args[0][2] == "/dockertarget/summaries"
def test_bad_env_path():
with pytest.raises(UnityEnvironmentException):
learn.create_environment_factory(
env_path="/foo/bar",
docker_target_name=None,
no_graphics=True,
seed=None,
start_port=8000,
env_args=None,
)
@patch("builtins.open", new_callable=mock_open, read_data="{}")

2
com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs.meta


fileFormatVersion: 2
guid: b651f66c75a1646c6ab48de06d0e13ef
guid: a5e0cbcbc514b473399c262dd37541ea
MonoImporter:
externalObjects: {}
serializedVersion: 2

2
com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs.meta


fileFormatVersion: 2
guid: 50f710d360a49461cad67ff5e6bcefe1
guid: f2902496c0120472b90269f94a0aec7e
MonoImporter:
externalObjects: {}
serializedVersion: 2

14
com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs.meta


fileFormatVersion: 2
guid: a79c7ccb2cd042b5b1e710b9588d921b
timeCreated: 1537388072
fileFormatVersion: 2
guid: ebaf7878a8cc74ee3aae07daf9e1b6f2
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

4
com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs


/// Used for imitation learning, or other forms of learning from data.
/// </summary>
[Serializable]
public class Demonstration : ScriptableObject
internal class Demonstration : ScriptableObject
{
public DemonstrationMetaData metaData;
public BrainParameters brainParameters;

/// Kept in a struct for easy serialization and deserialization.
/// </summary>
[Serializable]
public class DemonstrationMetaData
internal class DemonstrationMetaData
{
public int numberExperiences;
public int numberEpisodes;

104
com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs


using System.IO;
using System.IO.Abstractions;
using Google.Protobuf;
using System.Collections.Generic;

/// Responsible for writing demonstration data to file.
/// Responsible for writing demonstration data to stream (usually a file stream).
public class DemonstrationStore
public class DemonstrationWriter
readonly IFileSystem m_FileSystem;
const string k_DemoDirectory = "Assets/Demonstrations/";
const string k_ExtensionType = ".demo";
string m_FilePath;
public DemonstrationStore(IFileSystem fileSystem)
{
if (fileSystem != null)
{
m_FileSystem = fileSystem;
}
else
{
m_FileSystem = new FileSystem();
}
}
/// Initializes the Demonstration Store, and writes initial data.
/// Create a DemonstrationWriter that will write to the specified stream.
/// The stream must support writes and seeking.
public void Initialize(
string demonstrationName, BrainParameters brainParameters, string brainName)
/// <param name="stream"></param>
public DemonstrationWriter(Stream stream)
CreateDirectory();
CreateDemonstrationFile(demonstrationName);
WriteBrainParameters(brainName, brainParameters);
m_Writer = stream;
/// Checks for the existence of the Demonstrations directory
/// and creates it if it does not exist.
/// Writes the initial data to the stream.
void CreateDirectory()
public void Initialize(
string demonstrationName, BrainParameters brainParameters, string brainName)
if (!m_FileSystem.Directory.Exists(k_DemoDirectory))
if (m_Writer == null)
m_FileSystem.Directory.CreateDirectory(k_DemoDirectory);
// Already closed
return;
m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
var metaProto = m_MetaData.ToProto();
metaProto.WriteDelimitedTo(m_Writer);
WriteBrainParameters(brainName, brainParameters);
/// Creates demonstration file.
/// Writes meta-data. Note that this is called at the *end* of recording, but writes to the
/// beginning of the file.
void CreateDemonstrationFile(string demonstrationName)
void WriteMetadata()
// Creates demonstration file.
var literalName = demonstrationName;
m_FilePath = k_DemoDirectory + literalName + k_ExtensionType;
var uniqueNameCounter = 0;
while (m_FileSystem.File.Exists(m_FilePath))
if (m_Writer == null)
literalName = demonstrationName + "_" + uniqueNameCounter;
m_FilePath = k_DemoDirectory + literalName + k_ExtensionType;
uniqueNameCounter++;
// Already closed
return;
m_Writer = m_FileSystem.File.Create(m_FilePath);
m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
var metaProtoBytes = metaProto.ToByteArray();
m_Writer.Write(metaProtoBytes, 0, metaProtoBytes.Length);
m_Writer.Seek(0, 0);
metaProto.WriteDelimitedTo(m_Writer);
}
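`WriteMetadata` is called at the end of recording but overwrites the beginning of the stream, which only works because the serialized metadata occupies a fixed-size region (and is why the stream must support seeking). A sketch of the reserve-then-backfill pattern on an in-memory stream (the 16-byte header size is an assumption for illustration):

```python
import io

HEADER_SIZE = 16  # assumed fixed header size for this sketch

stream = io.BytesIO()
# 1. Reserve space for the header by writing a placeholder of the final size.
stream.write(b"\x00" * HEADER_SIZE)
# 2. Append experience records as they arrive.
stream.write(b"experience-data")
# 3. At the end, seek back and overwrite the placeholder with the real header.
stream.seek(0)
stream.write(b"count=42".ljust(HEADER_SIZE, b" "))

data = stream.getvalue()
assert data.startswith(b"count=42")
assert data.endswith(b"experience-data")
```

If the final header were larger than the reserved region, the backfill write would clobber the first experience record, which is exactly why the demonstration name is truncated to a fixed maximum length.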

void WriteBrainParameters(string brainName, BrainParameters brainParameters)
{
if (m_Writer == null)
{
// Already closed
return;
}
// Writes BrainParameters to file.
m_Writer.Seek(MetaDataBytes + 1, 0);
var brainProto = brainParameters.ToProto(brainName, false);

/// <summary>
/// Write AgentInfo experience to file.
/// </summary>
public void Record(AgentInfo info, List<ISensor> sensors)
internal void Record(AgentInfo info, List<ISensor> sensors)
if (m_Writer == null)
{
// Already closed
return;
}
// Increment meta-data counters.
m_MetaData.numberExperiences++;
m_CumulativeReward += info.reward;

agentProto.WriteDelimitedTo(m_Writer);
}
if (m_Writer == null)
{
// Already closed
return;
}
m_Writer = null;
}
/// <summary>

{
m_MetaData.numberEpisodes += 1;
}
/// <summary>
/// Writes meta-data.
/// </summary>
void WriteMetadata()
{
var metaProto = m_MetaData.ToProto();
var metaProtoBytes = metaProto.ToByteArray();
m_Writer.Write(metaProtoBytes, 0, metaProtoBytes.Length);
m_Writer.Seek(0, 0);
metaProto.WriteDelimitedTo(m_Writer);
}
}
}

8
com.unity.ml-agents/Runtime/Demonstrations.meta


fileFormatVersion: 2
guid: 85e02c21d231b4f5fa0c5f87e5f907a2
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

179
com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs


using System.IO.Abstractions;
using System.Text.RegularExpressions;
using UnityEngine;
using System.IO;
namespace MLAgents
{
/// <summary>
/// Demonstration Recorder Component.
/// </summary>
[RequireComponent(typeof(Agent))]
[AddComponentMenu("ML Agents/Demonstration Recorder", (int)MenuGroup.Default)]
public class DemonstrationRecorder : MonoBehaviour
{
[Tooltip("Whether or not to record demonstrations.")]
public bool record;
[Tooltip("Base demonstration file name. Numbers will be appended to make it unique.")]
public string demonstrationName;
[Tooltip("Base directory to write the demo files. If null, will use {Application.dataPath}/Demonstrations.")]
public string demonstrationDirectory;
DemonstrationWriter m_DemoWriter;
internal const int MaxNameLength = 16;
const string k_ExtensionType = ".demo";
IFileSystem m_FileSystem;
Agent m_Agent;
void OnEnable()
{
m_Agent = GetComponent<Agent>();
}
void Update()
{
if (record)
{
LazyInitialize();
}
}
/// <summary>
/// Creates demonstration store for use in recording.
/// Has no effect if the demonstration store was already created.
/// </summary>
internal DemonstrationWriter LazyInitialize(IFileSystem fileSystem = null)
{
if (m_DemoWriter != null)
{
return m_DemoWriter;
}
if (m_Agent == null)
{
m_Agent = GetComponent<Agent>();
}
m_FileSystem = fileSystem ?? new FileSystem();
var behaviorParams = GetComponent<BehaviorParameters>();
if (string.IsNullOrEmpty(demonstrationName))
{
demonstrationName = behaviorParams.behaviorName;
}
if (string.IsNullOrEmpty(demonstrationDirectory))
{
demonstrationDirectory = Path.Combine(Application.dataPath, "Demonstrations");
}
demonstrationName = SanitizeName(demonstrationName, MaxNameLength);
var filePath = MakeDemonstrationFilePath(m_FileSystem, demonstrationDirectory, demonstrationName);
var stream = m_FileSystem.File.Create(filePath);
m_DemoWriter = new DemonstrationWriter(stream);
m_DemoWriter.Initialize(
demonstrationName,
behaviorParams.brainParameters,
behaviorParams.fullyQualifiedBehaviorName
);
AddDemonstrationWriterToAgent(m_DemoWriter);
return m_DemoWriter;
}
/// <summary>
/// Removes all characters except alphanumerics from demonstration name.
/// Shorten name if it is longer than the maxNameLength.
/// </summary>
internal static string SanitizeName(string demoName, int maxNameLength)
{
var rgx = new Regex("[^a-zA-Z0-9 -]");
demoName = rgx.Replace(demoName, "");
// If the string is too long, it will overflow the metadata.
if (demoName.Length > maxNameLength)
{
demoName = demoName.Substring(0, maxNameLength);
}
return demoName;
}
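The `SanitizeName` regex keeps only alphanumerics, spaces, and hyphens, then truncates to `MaxNameLength` so the name cannot overflow the fixed-size metadata. A Python mirror of that logic (the `sanitize_name` function is an illustrative translation, not part of the codebase):

```python
import re

def sanitize_name(demo_name, max_name_length=16):
    """Keep alphanumerics, spaces, and hyphens; truncate to the max length."""
    demo_name = re.sub(r"[^a-zA-Z0-9 -]", "", demo_name)
    # Truncate so the name cannot overflow the fixed-size metadata block.
    return demo_name[:max_name_length]

assert sanitize_name("My Demo!@#") == "My Demo"
assert sanitize_name("averyveryverylongdemoname") == "averyveryverylon"
```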
/// <summary>
/// Gets a unique path for the demonstrationName in the demonstrationDirectory.
/// </summary>
/// <param name="fileSystem"></param>
/// <param name="demonstrationDirectory"></param>
/// <param name="demonstrationName"></param>
/// <returns></returns>
internal static string MakeDemonstrationFilePath(
IFileSystem fileSystem, string demonstrationDirectory, string demonstrationName
)
{
// Create the directory if it doesn't already exist
if (!fileSystem.Directory.Exists(demonstrationDirectory))
{
fileSystem.Directory.CreateDirectory(demonstrationDirectory);
}
var literalName = demonstrationName;
var filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType);
var uniqueNameCounter = 0;
while (fileSystem.File.Exists(filePath))
{
// TODO should we use a timestamp instead of a counter here? This loops an increasing number of times
// as the number of demos increases.
literalName = demonstrationName + "_" + uniqueNameCounter;
filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType);
uniqueNameCounter++;
}
return filePath;
}
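`MakeDemonstrationFilePath` avoids overwriting earlier recordings by appending an incrementing counter until the name is free (the TODO above notes this loop grows with the number of existing demos). A Python sketch of the same collision handling (the `make_unique_path` name is hypothetical):

```python
import os
import tempfile

def make_unique_path(directory, name, extension=".demo"):
    """Append _0, _1, ... to the base name until the path does not exist."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, name + extension)
    counter = 0
    while os.path.exists(path):
        path = os.path.join(directory, f"{name}_{counter}{extension}")
        counter += 1
    return path

demo_dir = tempfile.mkdtemp()
first = make_unique_path(demo_dir, "Agent")
open(first, "w").close()               # simulate an existing recording
second = make_unique_path(demo_dir, "Agent")
assert first.endswith("Agent.demo")
assert second.endswith("Agent_0.demo")
```

A timestamp suffix, as the TODO suggests, would make this O(1) at the cost of less readable file names.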
/// <summary>
/// Close the DemonstrationWriter and remove it from the Agent.
/// Has no effect if the DemonstrationWriter is already closed (or wasn't opened)
/// </summary>
public void Close()
{
if (m_DemoWriter != null)
{
RemoveDemonstrationWriterFromAgent(m_DemoWriter);
m_DemoWriter.Close();
m_DemoWriter = null;
}
}
/// <summary>
/// Clean up the DemonstrationWriter when shutting down or destroying the Agent.
/// </summary>
void OnDestroy()
{
Close();
}
/// <summary>
/// Add an additional DemonstrationWriter to the Agent. It is still up to the user to close this
/// DemonstrationWriter when recording is done.
/// </summary>
/// <param name="demoWriter"></param>
public void AddDemonstrationWriterToAgent(DemonstrationWriter demoWriter)
{
m_Agent.DemonstrationWriters.Add(demoWriter);
}
/// <summary>
/// Remove an additional DemonstrationWriter from the Agent. It is still up to the user to close this
/// DemonstrationWriter when recording is done.
/// </summary>
/// <param name="demoWriter"></param>
public void RemoveDemonstrationWriterFromAgent(DemonstrationWriter demoWriter)
{
m_Agent.DemonstrationWriters.Remove(demoWriter);
}
}
}

95
com.unity.ml-agents/Runtime/DemonstrationRecorder.cs


using System.IO.Abstractions;
using System.Text.RegularExpressions;
using UnityEngine;
using System.Collections.Generic;
namespace MLAgents
{
/// <summary>
/// Demonstration Recorder Component.
/// </summary>
[RequireComponent(typeof(Agent))]
[AddComponentMenu("ML Agents/Demonstration Recorder", (int)MenuGroup.Default)]
public class DemonstrationRecorder : MonoBehaviour
{
public bool record;
public string demonstrationName;
string m_FilePath;
DemonstrationStore m_DemoStore;
public const int MaxNameLength = 16;
void Start()
{
if (Application.isEditor && record)
{
InitializeDemoStore();
}
}
void Update()
{
if (Application.isEditor && record && m_DemoStore == null)
{
InitializeDemoStore();
}
}
/// <summary>
/// Creates demonstration store for use in recording.
/// </summary>
public void InitializeDemoStore(IFileSystem fileSystem = null)
{
m_DemoStore = new DemonstrationStore(fileSystem);
var behaviorParams = GetComponent<BehaviorParameters>();
demonstrationName = SanitizeName(demonstrationName, MaxNameLength);
m_DemoStore.Initialize(
demonstrationName,
behaviorParams.brainParameters,
behaviorParams.fullyQualifiedBehaviorName);
}
/// <summary>
/// Removes all characters except alphanumerics from demonstration name.
/// Shorten name if it is longer than the maxNameLength.
/// </summary>
public static string SanitizeName(string demoName, int maxNameLength)
{
var rgx = new Regex("[^a-zA-Z0-9 -]");
demoName = rgx.Replace(demoName, "");
// If the string is too long, it will overflow the metadata.
if (demoName.Length > maxNameLength)
{
demoName = demoName.Substring(0, maxNameLength);
}
return demoName;
}
/// <summary>
/// Forwards AgentInfo to Demonstration Store.
/// </summary>
public void WriteExperience(AgentInfo info, List<ISensor> sensors)
{
m_DemoStore?.Record(info, sensors);
}
public void Close()
{
if (m_DemoStore != null)
{
m_DemoStore.Close();
m_DemoStore = null;
}
}
/// <summary>
/// Closes Demonstration store.
/// </summary>
void OnApplicationQuit()
{
if (Application.isEditor && record)
{
Close();
}
}
}
}

/com.unity.ml-agents/Runtime/Demonstration.cs.meta → /com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs.meta

/com.unity.ml-agents/Runtime/DemonstrationRecorder.cs.meta → /com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs.meta

/com.unity.ml-agents/Runtime/DemonstrationStore.cs.meta → /com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs.meta

/com.unity.ml-agents/Runtime/Demonstration.cs → /com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs

/com.unity.ml-agents/Runtime/DemonstrationStore.cs → /com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs
