浏览代码

resolving merge conflicts

/release-0.14.0
Anupam Bhatnagar 5 年前
当前提交
d8c79f48
共有 84 个文件被更改,包括 558 次插入407 次删除
  1. 10
      Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
  2. 8
      Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs
  3. 4
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs
  4. 6
      Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs
  5. 30
      Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
  6. 11
      Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
  7. 14
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  8. 4
      Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs
  9. 6
      Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs
  10. 24
      Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
  11. 8
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs
  12. 2
      Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs
  13. 20
      Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
  14. 30
      Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
  15. 6
      Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
  16. 2
      com.unity.ml-agents/Editor/AgentEditor.cs
  17. 3
      com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
  18. 2
      com.unity.ml-agents/Editor/BrainParametersDrawer.cs
  19. 2
      com.unity.ml-agents/Editor/DemonstrationDrawer.cs
  20. 2
      com.unity.ml-agents/Editor/DemonstrationImporter.cs
  21. 61
      com.unity.ml-agents/Runtime/Academy.cs
  22. 48
      com.unity.ml-agents/Runtime/ActionMasker.cs
  23. 195
      com.unity.ml-agents/Runtime/Agent.cs
  24. 4
      com.unity.ml-agents/Runtime/DemonstrationRecorder.cs
  25. 1
      com.unity.ml-agents/Runtime/DemonstrationStore.cs
  26. 1
      com.unity.ml-agents/Runtime/Grpc/GrpcExtensions.cs
  27. 13
      com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs
  28. 7
      com.unity.ml-agents/Runtime/ICommunicator.cs
  29. 1
      com.unity.ml-agents/Runtime/InferenceBrain/BarracudaModelParamLoader.cs
  30. 1
      com.unity.ml-agents/Runtime/InferenceBrain/GeneratorImpl.cs
  31. 1
      com.unity.ml-agents/Runtime/InferenceBrain/ModelRunner.cs
  32. 1
      com.unity.ml-agents/Runtime/InferenceBrain/TensorGenerator.cs
  33. 2
      com.unity.ml-agents/Runtime/InferenceBrain/Utils/Multinomial.cs
  34. 2
      com.unity.ml-agents/Runtime/InferenceBrain/Utils/RandomNormal.cs
  35. 1
      com.unity.ml-agents/Runtime/Policy/BarracudaPolicy.cs
  36. 1
      com.unity.ml-agents/Runtime/Policy/HeuristicPolicy.cs
  37. 1
      com.unity.ml-agents/Runtime/Policy/IPolicy.cs
  38. 1
      com.unity.ml-agents/Runtime/Policy/RemotePolicy.cs
  39. 2
      com.unity.ml-agents/Runtime/Sensor/CameraSensor.cs
  40. 2
      com.unity.ml-agents/Runtime/Sensor/CameraSensorComponent.cs
  41. 2
      com.unity.ml-agents/Runtime/Sensor/ISensor.cs
  42. 2
      com.unity.ml-agents/Runtime/Sensor/Observation.cs
  43. 2
      com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs
  44. 2
      com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs
  45. 2
      com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs
  46. 2
      com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs
  47. 2
      com.unity.ml-agents/Runtime/Sensor/RenderTextureSensor.cs
  48. 2
      com.unity.ml-agents/Runtime/Sensor/RenderTextureSensorComponent.cs
  49. 2
      com.unity.ml-agents/Runtime/Sensor/SensorBase.cs
  50. 2
      com.unity.ml-agents/Runtime/Sensor/SensorComponent.cs
  51. 2
      com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs
  52. 2
      com.unity.ml-agents/Runtime/Sensor/StackingSensor.cs
  53. 2
      com.unity.ml-agents/Runtime/Sensor/VectorSensor.cs
  54. 2
      com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs
  55. 1
      com.unity.ml-agents/Runtime/Utilities.cs
  56. 9
      com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs
  57. 47
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  58. 1
      com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
  59. 1
      com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs
  60. 1
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
  61. 1
      com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
  62. 2
      com.unity.ml-agents/Tests/Editor/Sensor/WriterAdapterTests.cs
  63. 8
      docs/Getting-Started-with-Balance-Ball.md
  64. 4
      docs/Learning-Environment-Best-Practices.md
  65. 20
      docs/Learning-Environment-Create-New.md
  66. 44
      docs/Learning-Environment-Design-Agents.md
  67. 8
      docs/Learning-Environment-Design.md
  68. 4
      docs/Limitations.md
  69. 12
      docs/Migrating.md
  70. 184
      docs/Python-API.md
  71. 3
      docs/Reward-Signals.md
  72. 2
      docs/Training-Generalized-Reinforcement-Learning-Agents.md
  73. 2
      docs/Training-Self-Play.md
  74. 3
      docs/Unity-Inference-Engine.md
  75. 2
      gym-unity/gym_unity/__init__.py
  76. 2
      ml-agents-envs/mlagents_envs/__init__.py
  77. 3
      ml-agents-envs/mlagents_envs/environment.py
  78. 2
      ml-agents/mlagents/trainers/__init__.py
  79. 2
      ml-agents/mlagents/trainers/learn.py
  80. 2
      notebooks/getting-started.ipynb
  81. 4
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/Monitor.cs
  82. 11
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/Monitor.cs.meta
  83. 12
      com.unity.ml-agents/Runtime/Monitor.cs.meta
  84. 0
      /Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/Monitor.cs

10
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs


SetResetParameters();
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(gameObject.transform.rotation.z);
AddVectorObs(gameObject.transform.rotation.x);
AddVectorObs(ball.transform.position - gameObject.transform.position);
AddVectorObs(m_BallRb.velocity);
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
sensor.AddObservation(ball.transform.position - gameObject.transform.position);
sensor.AddObservation(m_BallRb.velocity);
}
public override void AgentAction(float[] vectorAction)

8
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs


SetResetParameters();
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(gameObject.transform.rotation.z);
AddVectorObs(gameObject.transform.rotation.x);
AddVectorObs((ball.transform.position - gameObject.transform.position));
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
sensor.AddObservation((ball.transform.position - gameObject.transform.position));
}
public override void AgentAction(float[] vectorAction)

4
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs


{
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(m_Position, 20);
sensor.AddOneHotObservation(m_Position, 20);
}
public override void AgentAction(float[] vectorAction)

6
Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs


SetResetParameters();
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(gameObject.transform.localPosition);
AddVectorObs(target.transform.localPosition);
sensor.AddObservation(gameObject.transform.localPosition);
sensor.AddObservation(target.transform.localPosition);
}
public override void AgentAction(float[] vectorAction)

30
Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs


/// <summary>
/// Add relevant information on each body part to observations.
/// </summary>
public void CollectObservationBodyPart(BodyPart bp)
public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground
sensor.AddObservation(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground
AddVectorObs(velocityRelativeToLookRotationToTarget);
sensor.AddObservation(velocityRelativeToLookRotationToTarget);
AddVectorObs(angularVelocityRelativeToLookRotationToTarget);
sensor.AddObservation(angularVelocityRelativeToLookRotationToTarget);
AddVectorObs(localPosRelToBody);
AddVectorObs(bp.currentXNormalizedRot); // Current x rot
AddVectorObs(bp.currentYNormalizedRot); // Current y rot
AddVectorObs(bp.currentZNormalizedRot); // Current z rot
AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit);
sensor.AddObservation(localPosRelToBody);
sensor.AddObservation(bp.currentXNormalizedRot); // Current x rot
sensor.AddObservation(bp.currentYNormalizedRot); // Current y rot
sensor.AddObservation(bp.currentZNormalizedRot); // Current z rot
sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
m_JdController.GetCurrentJointForces();

RaycastHit hit;
if (Physics.Raycast(body.position, Vector3.down, out hit, 10.0f))
{
AddVectorObs(hit.distance);
sensor.AddObservation(hit.distance);
AddVectorObs(10.0f);
sensor.AddObservation(10.0f);
AddVectorObs(bodyForwardRelativeToLookRotationToTarget);
sensor.AddObservation(bodyForwardRelativeToLookRotationToTarget);
AddVectorObs(bodyUpRelativeToLookRotationToTarget);
sensor.AddObservation(bodyUpRelativeToLookRotationToTarget);
CollectObservationBodyPart(bodyPart);
CollectObservationBodyPart(bodyPart, sensor);
}
}

11
Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs


{
base.InitializeAgent();
m_AgentRb = GetComponent<Rigidbody>();
Monitor.verticalOffset = 1f;
m_MyArea = area.GetComponent<FoodCollectorArea>();
m_FoodCollecterSettings = FindObjectOfType<FoodCollectorSettings>();

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(localVelocity.x);
AddVectorObs(localVelocity.z);
AddVectorObs(System.Convert.ToInt32(m_Frozen));
AddVectorObs(System.Convert.ToInt32(m_Shoot));
sensor.AddObservation(localVelocity.x);
sensor.AddObservation(localVelocity.z);
sensor.AddObservation(System.Convert.ToInt32(m_Frozen));
sensor.AddObservation(System.Convert.ToInt32(m_Shoot));
}
}

14
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


{
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
{
// There are no numeric observations to collect as this environment uses visual
// observations.

{
SetMask();
SetMask(actionMasker);
}
}

void SetMask()
void SetMask(ActionMasker actionMasker)
{
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;

if (positionX == 0)
{
SetActionMask(k_Left);
actionMasker.SetActionMask(k_Left);
SetActionMask(k_Right);
actionMasker.SetActionMask(k_Right);
SetActionMask(k_Down);
actionMasker.SetActionMask(k_Down);
SetActionMask(k_Up);
actionMasker.SetActionMask(k_Up);
}
}

4
Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs


m_GroundMaterial = m_GroundRenderer.material;
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(GetStepCount() / (float)maxStep);
sensor.AddObservation(GetStepCount() / (float)maxStep);
}
}

6
Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs


m_SwitchLogic = areaSwitch.GetComponent<PyramidSwitch>();
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(m_SwitchLogic.GetState());
AddVectorObs(transform.InverseTransformDirection(m_AgentRb.velocity));
sensor.AddObservation(m_SwitchLogic.GetState());
sensor.AddObservation(transform.InverseTransformDirection(m_AgentRb.velocity));
}
}

24
Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs


/// We collect the normalized rotations, angular velocities, and velocities of both
/// limbs of the reacher as well as the relative position of the target and hand.
/// </summary>
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(pendulumA.transform.localPosition);
AddVectorObs(pendulumA.transform.rotation);
AddVectorObs(m_RbA.angularVelocity);
AddVectorObs(m_RbA.velocity);
sensor.AddObservation(pendulumA.transform.localPosition);
sensor.AddObservation(pendulumA.transform.rotation);
sensor.AddObservation(m_RbA.angularVelocity);
sensor.AddObservation(m_RbA.velocity);
AddVectorObs(pendulumB.transform.localPosition);
AddVectorObs(pendulumB.transform.rotation);
AddVectorObs(m_RbB.angularVelocity);
AddVectorObs(m_RbB.velocity);
sensor.AddObservation(pendulumB.transform.localPosition);
sensor.AddObservation(pendulumB.transform.rotation);
sensor.AddObservation(m_RbB.angularVelocity);
sensor.AddObservation(m_RbB.velocity);
AddVectorObs(goal.transform.localPosition);
AddVectorObs(hand.transform.localPosition);
sensor.AddObservation(goal.transform.localPosition);
sensor.AddObservation(hand.transform.localPosition);
AddVectorObs(m_GoalSpeed);
sensor.AddObservation(m_GoalSpeed);
}
/// <summary>

8
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs


public class ProjectSettingsOverrides : MonoBehaviour
{
// Original values
float m_OriginalMonitorVerticalOffset;
Vector3 m_OriginalGravity;
float m_OriginalFixedDeltaTime;
float m_OriginalMaximumDeltaTime;

[Tooltip("Increase or decrease the scene gravity. Use ~3x to make things less floaty")]
public float gravityMultiplier = 1.0f;
[Header("Display Settings")]
public float monitorVerticalOffset;
[Header("Advanced physics settings")]
[Tooltip("The interval in seconds at which physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) are performed.")]
public float fixedDeltaTime = .02f;

public void Awake()
{
// Save the original values
m_OriginalMonitorVerticalOffset = Monitor.verticalOffset;
m_OriginalGravity = Physics.gravity;
m_OriginalFixedDeltaTime = Time.fixedDeltaTime;
m_OriginalMaximumDeltaTime = Time.maximumDeltaTime;

// Override
Monitor.verticalOffset = monitorVerticalOffset;
Physics.gravity *= gravityMultiplier;
Time.fixedDeltaTime = fixedDeltaTime;
Time.maximumDeltaTime = maximumDeltaTime;

public void OnDestroy()
{
Monitor.verticalOffset = m_OriginalMonitorVerticalOffset;
Physics.gravity = m_OriginalGravity;
Time.fixedDeltaTime = m_OriginalFixedDeltaTime;
Time.maximumDeltaTime = m_OriginalMaximumDeltaTime;

2
Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs


public class TemplateAgent : Agent
{
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
}

20
Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs


SetResetParameters();
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(m_InvertMult * (transform.position.x - myArea.transform.position.x));
AddVectorObs(transform.position.y - myArea.transform.position.y);
AddVectorObs(m_InvertMult * m_AgentRb.velocity.x);
AddVectorObs(m_AgentRb.velocity.y);
sensor.AddObservation(m_InvertMult * (transform.position.x - myArea.transform.position.x));
sensor.AddObservation(transform.position.y - myArea.transform.position.y);
sensor.AddObservation(m_InvertMult * m_AgentRb.velocity.x);
sensor.AddObservation(m_AgentRb.velocity.y);
AddVectorObs(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x));
AddVectorObs(ball.transform.position.y - myArea.transform.position.y);
AddVectorObs(m_InvertMult * m_BallRb.velocity.x);
AddVectorObs(m_BallRb.velocity.y);
sensor.AddObservation(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x));
sensor.AddObservation(ball.transform.position.y - myArea.transform.position.y);
sensor.AddObservation(m_InvertMult * m_BallRb.velocity.x);
sensor.AddObservation(m_BallRb.velocity.y);
AddVectorObs(m_InvertMult * gameObject.transform.rotation.z);
sensor.AddObservation(m_InvertMult * gameObject.transform.rotation.z);
}
public override void AgentAction(float[] vectorAction)

30
Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs


/// <summary>
/// Add relevant information on each body part to observations.
/// </summary>
public void CollectObservationBodyPart(BodyPart bp)
public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Is this bp touching the ground
AddVectorObs(rb.velocity);
AddVectorObs(rb.angularVelocity);
sensor.AddObservation(bp.groundContact.touchingGround ? 1 : 0); // Is this bp touching the ground
sensor.AddObservation(rb.velocity);
sensor.AddObservation(rb.angularVelocity);
AddVectorObs(localPosRelToHips);
sensor.AddObservation(localPosRelToHips);
AddVectorObs(bp.currentXNormalizedRot);
AddVectorObs(bp.currentYNormalizedRot);
AddVectorObs(bp.currentZNormalizedRot);
AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit);
sensor.AddObservation(bp.currentXNormalizedRot);
sensor.AddObservation(bp.currentYNormalizedRot);
sensor.AddObservation(bp.currentZNormalizedRot);
sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
}
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(m_DirToTarget.normalized);
AddVectorObs(m_JdController.bodyPartsDict[hips].rb.position);
AddVectorObs(hips.forward);
AddVectorObs(hips.up);
sensor.AddObservation(m_DirToTarget.normalized);
sensor.AddObservation(m_JdController.bodyPartsDict[hips].rb.position);
sensor.AddObservation(hips.forward);
sensor.AddObservation(hips.up);
CollectObservationBodyPart(bodyPart);
CollectObservationBodyPart(bodyPart, sensor);
}
}

6
Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs


}
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(agentPos / 20f);
AddVectorObs(DoGroundCheck(true) ? 1 : 0);
sensor.AddObservation(agentPos / 20f);
sensor.AddObservation(DoGroundCheck(true) ? 1 : 0);
}
/// <summary>

2
com.unity.ml-agents/Editor/AgentEditor.cs


*/
[CustomEditor(typeof(Agent), true)]
[CanEditMultipleObjects]
public class AgentEditor : Editor
internal class AgentEditor : Editor
{
public override void OnInspectorGUI()
{

3
com.unity.ml-agents/Editor/BehaviorParametersEditor.cs


using UnityEngine;
using UnityEditor;
using Barracuda;
using MLAgents.Sensor;
namespace MLAgents
{

[CustomEditor(typeof(BehaviorParameters))]
[CanEditMultipleObjects]
public class BehaviorParametersEditor : Editor
internal class BehaviorParametersEditor : Editor
{
const float k_TimeBetweenModelReloads = 2f;
// Time since the last reload of the model

2
com.unity.ml-agents/Editor/BrainParametersDrawer.cs


/// Inspector.
/// </summary>
[CustomPropertyDrawer(typeof(BrainParameters))]
public class BrainParametersDrawer : PropertyDrawer
internal class BrainParametersDrawer : PropertyDrawer
{
// The height of a line in the Unity Inspectors
const float k_LineHeight = 17f;

2
com.unity.ml-agents/Editor/DemonstrationDrawer.cs


/// </summary>
[CustomEditor(typeof(Demonstration))]
[CanEditMultipleObjects]
public class DemonstrationEditor : Editor
internal class DemonstrationEditor : Editor
{
SerializedProperty m_BrainParameters;
SerializedProperty m_DemoMetaData;

2
com.unity.ml-agents/Editor/DemonstrationImporter.cs


/// Asset Importer used to parse demonstration files.
/// </summary>
[ScriptedImporter(1, new[] {"demo"})]
public class DemonstrationImporter : ScriptedImporter
internal class DemonstrationImporter : ScriptedImporter
{
const string k_IconPath = "Assets/ML-Agents/Resources/DemoIcon.png";

61
com.unity.ml-agents/Runtime/Academy.cs


"docs/Learning-Environment-Design.md")]
public class Academy : IDisposable
{
const string k_ApiVersion = "API-14";
const string k_ApiVersion = "API-15-dev0";
const int k_EditorTrainingPort = 5004;
// Lazy initializer pattern, see https://csharpindepth.com/articles/singleton#lazy

{
Application.quitting += Dispose;
LazyInitialization();
LazyInitialize();
}
/// <summary>

internal void LazyInitialization()
internal void LazyInitialize()
{
if (!m_Initialized)
{

/// Enable stepping of the Academy during the FixedUpdate phase. This is done by creating a temporary
/// GameObject with a MonoBehaviour that calls Academy.EnvironmentStep().
/// </summary>
public void EnableAutomaticStepping()
void EnableAutomaticStepping()
{
if (m_FixedUpdateStepper != null)
{

}
/// <summary>
/// Registers a SideChannel with the Academy to send and receive data with Python.
/// If IsCommunicatorOn is false, the SideChannel will not be registered.
/// </summary>
/// <param name="channel">The side channel to be registered.</param>
public void RegisterSideChannel(SideChannel channel)
{
// Run the Academy's lazy initialization before touching the Communicator.
LazyInitialize();
// Null-conditional call: nothing is registered when Communicator is null.
Communicator?.RegisterSideChannel(channel);
}
/// <summary>
/// Unregisters a SideChannel from the Academy. If the side channel was not registered,
/// nothing will happen.
/// </summary>
/// <param name="channel">The side channel to be unregistered.</param>
public void UnregisterSideChannel(SideChannel channel)
{
// Null-conditional call: this is a no-op when Communicator is null.
Communicator?.UnregisterSideChannel(channel);
}
/// <summary>
public void DisableAutomaticStepping(bool destroyImmediate = false)
void DisableAutomaticStepping()
{
if (m_FixedUpdateStepper == null)
{

m_FixedUpdateStepper = null;
if (destroyImmediate)
if (Application.isEditor)
{
UnityEngine.Object.DestroyImmediate(m_StepperObject);
}

}
/// <summary>
/// Returns whether or not the Academy is automatically stepped during the FixedUpdate phase.
/// Determines whether or not the Academy is automatically stepped during the FixedUpdate phase.
public bool IsAutomaticSteppingEnabled
public bool AutomaticSteppingEnabled
set {
if (value)
{
EnableAutomaticStepping();
}
else
{
DisableAutomaticStepping();
}
}
}
// Used to read Python-provided environment parameters

/// <returns>
/// Current episode number.
/// </returns>
public int GetEpisodeCount()
public int EpisodeCount
return m_EpisodeCount;
get { return m_EpisodeCount; }
}
/// <summary>

/// Current step count.
/// </returns>
public int GetStepCount()
public int StepCount
return m_StepCount;
get { return m_StepCount; }
}
/// <summary>

/// Total step count.
/// </returns>
public int GetTotalStepCount()
public int TotalStepCount
return m_TotalStepCount;
get { return m_TotalStepCount; }
}
/// <summary>

/// </summary>
public void Dispose()
{
DisableAutomaticStepping(true);
DisableAutomaticStepping();
// Signal to listeners that the academy is being destroyed now
DestroyAction?.Invoke();

48
com.unity.ml-agents/Runtime/ActionMasker.cs


namespace MLAgents
{
internal class ActionMasker
public class ActionMasker
{
/// When using discrete control, is the starting indices of the actions
/// when all the branches are concatenated with each other.

}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the actions passed as arguments at the next decision.
/// The actionIndices correspond to the actions the agent will be unable to perform
/// on branch 0.
/// </summary>
/// <param name="actionIndices">The indices of the masked actions on branch 0</param>
public void SetActionMask(IEnumerable<int> actionIndices)
{
// Convenience overload: forwards to the (branch, actionIndices) overload with branch 0.
SetActionMask(0, actionIndices);
}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision for the specified
/// action branch. The actionIndex corresponds to the action the agent will be unable
/// to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndex">The index of the masked action</param>
public void SetActionMask(int branch, int actionIndex)
{
// Wrap the single index in a one-element array and forward to the
// (branch, actionIndices) overload.
SetActionMask(branch, new[] { actionIndex });
}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. The actionIndex
/// corresponds to the action the agent will be unable to perform on branch 0.
/// </summary>
/// <param name="actionIndex">The index of the masked action on branch 0</param>
public void SetActionMask(int actionIndex)
{
// Convenience overload: branch 0, with the single index wrapped in a one-element array.
SetActionMask(0, new[] { actionIndex });
}
/// <summary>
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// able to perform the actions passed as argument at the next decision for the specified
/// action branch. The actionIndices correspond to the action options the agent will
/// be unable to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndices">The indices of the masked actions</param>

/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
public bool[] GetMask()
internal bool[] GetMask()
{
if (m_CurrentMask != null)
{

/// <summary>
/// Resets the current mask for an agent
/// </summary>
public void ResetMask()
internal void ResetMask()
{
if (m_CurrentMask != null)
{

195
com.unity.ml-agents/Runtime/Agent.cs


using System.Collections.Generic;
using UnityEngine;
using Barracuda;
using MLAgents.Sensor;
using UnityEngine.Serialization;
namespace MLAgents

UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
CollectObservations();
CollectObservations(collectObservationsSensor, m_ActionMasker);
}
m_Info.actionMasks = m_ActionMasker.GetMask();

}
/// <summary>
/// Collects the (vector, visual) observations of the agent.
/// Collects the vector observations of the agent.
/// Simply, an agents observation is any environment information that helps
/// the Agent acheive its goal. For example, for a fighting Agent, its
/// An agents observation is any environment information that helps
/// the Agent achieve its goal. For example, for a fighting Agent, its
/// Vector observations are added by calling the provided helper methods:
/// - <see cref="AddVectorObs(int)"/>
/// - <see cref="AddVectorObs(float)"/>
/// - <see cref="AddVectorObs(Vector3)"/>
/// - <see cref="AddVectorObs(Vector2)"/>
/// - <see>
/// <cref>AddVectorObs(float[])</cref>
/// </see>
/// - <see>
/// <cref>AddVectorObs(List{float})</cref>
/// </see>
/// - <see cref="AddVectorObs(Quaternion)"/>
/// - <see cref="AddVectorObs(bool)"/>
/// - <see cref="AddVectorObs(int, int)"/>
/// Vector observations are added by calling the provided helper methods
/// on the VectorSensor input:
/// - <see cref="AddObservation(int)"/>
/// - <see cref="AddObservation(float)"/>
/// - <see cref="AddObservation(Vector3)"/>
/// - <see cref="AddObservation(Vector2)"/>
/// - <see cref="AddObservation(Quaternion)"/>
/// - <see cref="AddObservation(bool)"/>
/// - <see cref="AddOneHotObservation(int, int)"/>
/// Depending on your environment, any combination of these helpers can
/// be used. They just need to be used in the exact same order each time
/// this method is called and the resulting size of the vector observation

/// </remarks>
public virtual void CollectObservations()
{
}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// </summary>
/// <param name="actionIndices">The indices of the masked actions on branch 0</param>
protected void SetActionMask(IEnumerable<int> actionIndices)
{
m_ActionMasker.SetActionMask(0, actionIndices);
}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// </summary>
/// <param name="actionIndex">The index of the masked action on branch 0</param>
protected void SetActionMask(int actionIndex)
{
m_ActionMasker.SetActionMask(0, new[] { actionIndex });
}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndex">The index of the masked action</param>
protected void SetActionMask(int branch, int actionIndex)
{
m_ActionMasker.SetActionMask(branch, new[] { actionIndex });
}
/// <summary>
/// Modifies an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndices">The indices of the masked actions</param>
protected void SetActionMask(int branch, IEnumerable<int> actionIndices)
{
m_ActionMasker.SetActionMask(branch, actionIndices);
}
/// <summary>
/// Adds a float observation to the vector observations of the agent.
/// Increases the size of the agents vector observation by 1.
/// </summary>
/// <param name="observation">Observation.</param>
protected void AddVectorObs(float observation)
{
collectObservationsSensor.AddObservation(observation);
}
/// <summary>
/// Adds an integer observation to the vector observations of the agent.
/// Increases the size of the agents vector observation by 1.
/// </summary>
/// <param name="observation">Observation.</param>
protected void AddVectorObs(int observation)
public virtual void CollectObservations(VectorSensor sensor)
collectObservationsSensor.AddObservation(observation);
/// Adds an Vector3 observation to the vector observations of the agent.
/// Increases the size of the agents vector observation by 3.
/// Collects the vector observations of the agent.
/// The agent observation describes the current environment from the
/// perspective of the agent.
/// <param name="observation">Observation.</param>
protected void AddVectorObs(Vector3 observation)
/// <remarks>
/// An agent's observation is any environment information that helps
/// the Agent achieve its goal. For example, for a fighting Agent, its
/// observation could include distances to friends or enemies, or the
/// current level of ammunition at its disposal.
/// Recall that an Agent may attach vector or visual observations.
/// Vector observations are added by calling the provided helper methods
/// on the VectorSensor input:
/// - <see cref="AddObservation(int)"/>
/// - <see cref="AddObservation(float)"/>
/// - <see cref="AddObservation(Vector3)"/>
/// - <see cref="AddObservation(Vector2)"/>
/// - <see cref="AddObservation(Quaternion)"/>
/// - <see cref="AddObservation(bool)"/>
/// - <see cref="AddOneHotObservation(int, int)"/>
/// Depending on your environment, any combination of these helpers can
/// be used. They just need to be used in the exact same order each time
/// this method is called and the resulting size of the vector observation
/// needs to match the vectorObservationSize attribute of the linked Brain.
/// Visual observations are implicitly added from the cameras attached to
/// the Agent.
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it. You can call the following method on the ActionMasker
/// input:
/// - <see cref="SetActionMask(int, IEnumerable{int})"/>
/// - <see cref="SetActionMask(int, int)"/>
/// - <see cref="SetActionMask(IEnumerable{int})"/>
/// - <see cref="SetActionMask(int, int)"/>
/// The branch input is the index of the action, actionIndices are the indices of the
/// invalid options for that action.
/// </remarks>
public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
collectObservationsSensor.AddObservation(observation);
}
/// <summary>
/// Appends a Vector2 value to the agent's vector observations.
/// The agent's vector observation grows by 2.
/// </summary>
/// <param name="observation">The Vector2 value to append.</param>
protected void AddVectorObs(Vector2 observation) =>
    collectObservationsSensor.AddObservation(observation);
/// <summary>
/// Appends every float in a collection to the agent's vector observations.
/// The agent's vector observation grows by the size of the collection.
/// </summary>
/// <param name="observation">The float values to append.</param>
protected void AddVectorObs(IEnumerable<float> observation) =>
    collectObservationsSensor.AddObservation(observation);
/// <summary>
/// Appends a quaternion value to the agent's vector observations.
/// The agent's vector observation grows by 4.
/// </summary>
/// <param name="observation">The quaternion value to append.</param>
protected void AddVectorObs(Quaternion observation) =>
    collectObservationsSensor.AddObservation(observation);
/// <summary>
/// Appends a boolean value to the agent's vector observations.
/// The agent's vector observation grows by 1.
/// </summary>
/// <param name="observation">The boolean value to append.</param>
protected void AddVectorObs(bool observation) =>
    collectObservationsSensor.AddObservation(observation);
/// <summary>
/// Forwards <paramref name="observation"/> and <paramref name="range"/> to
/// <c>collectObservationsSensor.AddOneHotObservation</c>.
/// </summary>
/// <param name="observation">Passed as the first argument of AddOneHotObservation.</param>
/// <param name="range">Passed as the second argument of AddOneHotObservation.</param>
protected void AddVectorObs(int observation, int range)
{
collectObservationsSensor.AddOneHotObservation(observation, range);
// NOTE(review): `sensor` is not declared anywhere in this method. This call
// looks like a line from a different revision that ended up in this listing;
// confirm whether it belongs here before relying on it.
CollectObservations(sensor);
}
/// <summary>

ResetData();
m_StepCount = 0;
AgentReset();
}
/// <summary>
/// Stores <paramref name="action"/> in <c>m_Action</c>, replacing the
/// value held there before.
/// </summary>
/// <param name="action">The action to store.</param>
internal void UpdateAgentAction(AgentAction action) => m_Action = action;
/// <summary>

4
com.unity.ml-agents/Runtime/DemonstrationRecorder.cs


using System.Text.RegularExpressions;
using UnityEngine;
using System.Collections.Generic;
using MLAgents.Sensor;
namespace MLAgents
{

{
public bool record;
public string demonstrationName;
Agent m_RecordingAgent;
string m_FilePath;
DemonstrationStore m_DemoStore;
public const int MaxNameLength = 16;

/// </summary>
public void InitializeDemoStore(IFileSystem fileSystem = null)
{
m_RecordingAgent = GetComponent<Agent>();
m_DemoStore = new DemonstrationStore(fileSystem);
var behaviorParams = GetComponent<BehaviorParameters>();
demonstrationName = SanitizeName(demonstrationName, MaxNameLength);

behaviorParams.fullyQualifiedBehaviorName);
Monitor.Log("Recording Demonstration of Agent: ", m_RecordingAgent.name);
}
/// <summary>

1
com.unity.ml-agents/Runtime/DemonstrationStore.cs


using System.IO.Abstractions;
using Google.Protobuf;
using System.Collections.Generic;
using MLAgents.Sensor;
namespace MLAgents
{

1
com.unity.ml-agents/Runtime/Grpc/GrpcExtensions.cs


using System.Linq;
using Google.Protobuf;
using MLAgents.CommunicatorObjects;
using MLAgents.Sensor;
using UnityEngine;
using System.Runtime.CompilerServices;

13
com.unity.ml-agents/Runtime/Grpc/RpcCommunicator.cs


using MLAgents.CommunicatorObjects;
using System.IO;
using Google.Protobuf;
using MLAgents.Sensor;
namespace MLAgents
{

"side channels of the same type.", channelType));
}
m_SideChannels.Add(channelType, sideChannel);
}
/// <summary>
/// Unregisters a side channel from the communicator.
/// Nothing happens when no channel is registered under the type reported
/// by <paramref name="sideChannel"/>.
/// </summary>
/// <param name="sideChannel"> The side channel to be unregistered.</param>
public void UnregisterSideChannel(SideChannel sideChannel)
{
    // Remove already returns false without throwing when the key is absent,
    // so a preceding ContainsKey probe only repeats the lookup and evaluates
    // ChannelType() a second time.
    m_SideChannels.Remove(sideChannel.ChannelType());
}
/// <summary>

7
com.unity.ml-agents/Runtime/ICommunicator.cs


using System;
using System.Collections.Generic;
using UnityEngine;
using MLAgents.Sensor;
namespace MLAgents
{

/// </summary>
/// <param name="sideChannel"> The side channel to be registered.</param>
void RegisterSideChannel(SideChannel sideChannel);
/// <summary>
/// Unregisters a side channel from the communicator.
/// </summary>
/// <param name="sideChannel"> The side channel to be unregistered.</param>
void UnregisterSideChannel(SideChannel sideChannel);
}
}

1
com.unity.ml-agents/Runtime/InferenceBrain/BarracudaModelParamLoader.cs


using System.Collections.Generic;
using System.Linq;
using Barracuda;
using MLAgents.Sensor;
using UnityEngine;
namespace MLAgents.InferenceBrain

1
com.unity.ml-agents/Runtime/InferenceBrain/GeneratorImpl.cs


using System;
using Barracuda;
using MLAgents.InferenceBrain.Utils;
using MLAgents.Sensor;
using UnityEngine;
namespace MLAgents.InferenceBrain

1
com.unity.ml-agents/Runtime/InferenceBrain/ModelRunner.cs


using Barracuda;
using UnityEngine.Profiling;
using System;
using MLAgents.Sensor;
namespace MLAgents.InferenceBrain
{

1
com.unity.ml-agents/Runtime/InferenceBrain/TensorGenerator.cs


using System.Collections.Generic;
using Barracuda;
using MLAgents.Sensor;
namespace MLAgents.InferenceBrain
{

2
com.unity.ml-agents/Runtime/InferenceBrain/Utils/Multinomial.cs


/// entry[i] = P(x \le i), NOT P(i - 1 \le x \lt i).
/// (\le stands for less than or equal to while \lt is strictly less than).
/// </summary>
public class Multinomial
internal class Multinomial
{
readonly System.Random m_Random;

2
com.unity.ml-agents/Runtime/InferenceBrain/Utils/RandomNormal.cs


/// https://en.wikipedia.org/wiki/Marsaglia_polar_method
/// TODO: worth overriding System.Random instead of aggregating?
/// </summary>
public class RandomNormal
internal class RandomNormal
{
readonly double m_Mean;
readonly double m_Stddev;

1
com.unity.ml-agents/Runtime/Policy/BarracudaPolicy.cs


using System.Collections.Generic;
using MLAgents.InferenceBrain;
using System;
using MLAgents.Sensor;
namespace MLAgents
{

1
com.unity.ml-agents/Runtime/Policy/HeuristicPolicy.cs


using MLAgents.Sensor;
using System.Collections.Generic;
using System;

1
com.unity.ml-agents/Runtime/Policy/IPolicy.cs


using System;
using System.Collections.Generic;
using MLAgents.Sensor;
namespace MLAgents
{

1
com.unity.ml-agents/Runtime/Policy/RemotePolicy.cs


using UnityEngine;
using System.Collections.Generic;
using MLAgents.Sensor;
using System;
namespace MLAgents

2
com.unity.ml-agents/Runtime/Sensor/CameraSensor.cs


using System;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
public class CameraSensor : ISensor
{

2
com.unity.ml-agents/Runtime/Sensor/CameraSensorComponent.cs


using System;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
[AddComponentMenu("ML Agents/Camera Sensor", (int)MenuGroup.Sensors)]
public class CameraSensorComponent : SensorComponent

2
com.unity.ml-agents/Runtime/Sensor/ISensor.cs


namespace MLAgents.Sensor
namespace MLAgents
{
public enum SensorCompressionType
{

2
com.unity.ml-agents/Runtime/Sensor/Observation.cs


using System;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
internal struct Observation
{

2
com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs


using System.Collections.Generic;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
public class RayPerceptionSensor : ISensor
{

2
com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs


using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
[AddComponentMenu("ML Agents/Ray Perception Sensor 2D", (int)MenuGroup.Sensors)]
public class RayPerceptionSensorComponent2D : RayPerceptionSensorComponentBase

2
com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs


using System;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
[AddComponentMenu("ML Agents/Ray Perception Sensor 3D", (int)MenuGroup.Sensors)]
public class RayPerceptionSensorComponent3D : RayPerceptionSensorComponentBase

2
com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs


using System.Collections.Generic;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
public abstract class RayPerceptionSensorComponentBase : SensorComponent
{

2
com.unity.ml-agents/Runtime/Sensor/RenderTextureSensor.cs


using System;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
public class RenderTextureSensor : ISensor
{

2
com.unity.ml-agents/Runtime/Sensor/RenderTextureSensorComponent.cs


using System;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
[AddComponentMenu("ML Agents/Render Texture Sensor", (int)MenuGroup.Sensors)]
public class RenderTextureSensorComponent : SensorComponent

2
com.unity.ml-agents/Runtime/Sensor/SensorBase.cs


using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
public abstract class SensorBase : ISensor
{

2
com.unity.ml-agents/Runtime/Sensor/SensorComponent.cs


using System;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
/// <summary>
/// Editor components for creating Sensors. Generally an ISensor implementation should have a corresponding

2
com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs


using System.Collections.Generic;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
internal class SensorShapeValidator
{

2
com.unity.ml-agents/Runtime/Sensor/StackingSensor.cs


namespace MLAgents.Sensor
namespace MLAgents
{
/// <summary>
/// Sensor that wraps around another Sensor to provide temporal stacking.

2
com.unity.ml-agents/Runtime/Sensor/VectorSensor.cs


using System.Collections.Generic;
using UnityEngine;
namespace MLAgents.Sensor
namespace MLAgents
{
public class VectorSensor : ISensor
{

2
com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs


using Barracuda;
using MLAgents.InferenceBrain;
namespace MLAgents.Sensor
namespace MLAgents
{
/// <summary>
/// Allows sensors to write to both TensorProxy and float arrays/lists.

1
com.unity.ml-agents/Runtime/Utilities.cs


using UnityEngine;
using System.Collections.Generic;
using MLAgents.Sensor;
namespace MLAgents
{

9
com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs


using System.IO.Abstractions.TestingHelpers;
using System.Reflection;
using MLAgents.CommunicatorObjects;
using MLAgents.Sensor;
namespace MLAgents.Tests
{

public class ObservationAgent : TestAgent
{
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(1f);
AddVectorObs(2f);
AddVectorObs(3f);
sensor.AddObservation(1f);
sensor.AddObservation(2f);
sensor.AddObservation(3f);
}
}

47
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


using UnityEngine;
using NUnit.Framework;
using System.Reflection;
using MLAgents.Sensor;
using System.Collections.Generic;
namespace MLAgents.Tests

sensors.Add(sensor1);
}
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
AddVectorObs(0f);
sensor.AddObservation(0f);
}
public override void AgentAction(float[] vectorAction)

{
var aca = Academy.Instance;
Assert.AreNotEqual(null, aca);
Assert.AreEqual(0, aca.GetEpisodeCount());
Assert.AreEqual(0, aca.GetStepCount());
Assert.AreEqual(0, aca.GetTotalStepCount());
Assert.AreEqual(0, aca.EpisodeCount);
Assert.AreEqual(0, aca.StepCount);
Assert.AreEqual(0, aca.TotalStepCount);
}
[Test]