浏览代码

Merge branch 'develop-add-fire-mm3' into develop-add-fire-checkpoint

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
d57aa9ab
共有 123 个文件被更改,包括 8191 次插入3207 次删除
  1. 20
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/DirectionIndicator.cs
  2. 977
      Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamic.unity
  3. 2
      Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamic.unity.meta
  4. 962
      Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerStatic.unity
  5. 161
      Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
  6. 1001
      Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamic.nn
  7. 2
      Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamic.nn.meta
  8. 1001
      Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStatic.nn
  9. 2
      Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStatic.nn.meta
  10. 4
      com.unity.ml-agents.extensions/Editor/Unity.ML-Agents.Extensions.Editor.asmdef
  11. 1
      com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs
  12. 29
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
  13. 8
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
  14. 54
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  15. 117
      com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
  16. 80
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
  17. 68
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
  18. 88
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
  19. 67
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
  20. 5
      com.unity.ml-agents/CHANGELOG.md
  21. 33
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
  22. 2
      com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
  23. 51
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
  24. 72
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
  25. 2
      com.unity.ml-agents/Runtime/Actuators/IActuator.cs
  26. 2
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
  27. 2
      com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
  28. 201
      com.unity.ml-agents/Runtime/Agent.cs
  29. 14
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  30. 2
      com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs
  31. 6
      com.unity.ml-agents/Runtime/DecisionRequester.cs
  32. 118
      com.unity.ml-agents/Runtime/DiscreteActionMasker.cs
  33. 17
      com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
  34. 37
      com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
  35. 20
      com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
  36. 3
      com.unity.ml-agents/Runtime/Policies/IPolicy.cs
  37. 15
      com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs
  38. 46
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
  39. 3
      com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs
  40. 4
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  41. 2
      config/ppo/WalkerDynamic.yaml
  42. 2
      config/ppo/WalkerStatic.yaml
  43. 31
      docs/Learning-Environment-Examples.md
  44. 3
      ml-agents/mlagents/trainers/optimizer/tf_optimizer.py
  45. 25
      ml-agents/mlagents/trainers/policy/tf_policy.py
  46. 1
      ml-agents/mlagents/trainers/policy/torch_policy.py
  47. 8
      ml-agents/mlagents/trainers/ppo/optimizer_tf.py
  48. 20
      ml-agents/mlagents/trainers/ppo/trainer.py
  49. 19
      ml-agents/mlagents/trainers/sac/trainer.py
  50. 6
      ml-agents/mlagents/trainers/saver/tf_saver.py
  51. 2
      ml-agents/mlagents/trainers/settings.py
  52. 41
      ml-agents/mlagents/trainers/stats.py
  53. 3
      ml-agents/mlagents/trainers/tests/mock_brain.py
  54. 62
      ml-agents/mlagents/trainers/tests/test_env_param_manager.py
  55. 114
      ml-agents/mlagents/trainers/tests/test_nn_policy.py
  56. 2
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py
  57. 1
      ml-agents/mlagents/trainers/tests/test_sac.py
  58. 4
      ml-agents/mlagents/trainers/tests/test_simple_rl.py
  59. 20
      ml-agents/mlagents/trainers/tf/model_serialization.py
  60. 26
      ml-agents/mlagents/trainers/tf/models.py
  61. 10
      ml-agents/mlagents/trainers/trainer_controller.py
  62. 21
      Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Targets/DynamicTarget.prefab
  63. 19
      Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Targets/StaticTarget.prefab
  64. 82
      Project/Assets/ML-Agents/Examples/Walker/Prefabs/Ragdoll/WalkerRagdollBase.prefab
  65. 523
      Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/PlatformDynamicTarget.prefab
  66. 7
      Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/PlatformDynamicTarget.prefab.meta
  67. 8
      Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Targets.meta
  68. 10
      Project/Assets/ML-Agents/Examples/Walker/Demos/ExpertWalkerDy.demo.meta
  69. 10
      Project/Assets/ML-Agents/Examples/Walker/Demos/ExpertWalkerDyVS.demo.meta
  70. 10
      Project/Assets/ML-Agents/Examples/Walker/Demos/ExpertWalkerStVS.demo.meta
  71. 10
      Project/Assets/ML-Agents/Examples/Walker/Demos/ExpertWalkerSta.demo.meta
  72. 8
      Project/Assets/ML-Agents/Examples/Walker/Prefabs/Platforms.meta
  73. 8
      Project/Assets/ML-Agents/Examples/Walker/Prefabs/Ragdoll.meta
  74. 1001
      Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamicVariableSpeed.unity
  75. 7
      Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamicVariableSpeed.unity.meta
  76. 1001
      Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerStaticVariableSpeed.unity
  77. 9
      Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerStaticVariableSpeed.unity.meta
  78. 1001
      Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamicVariableSpeed.nn
  79. 11
      Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamicVariableSpeed.nn.meta
  80. 1001
      Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStaticVariableSpeed.nn
  81. 11
      Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStaticVariableSpeed.nn.meta
  82. 66
      com.unity.ml-agents.extensions/Editor/RigidBodySensorComponentEditor.cs
  83. 11
      com.unity.ml-agents.extensions/Editor/RigidBodySensorComponentEditor.cs.meta
  84. 38
      com.unity.ml-agents/Runtime/Agent.deprecated.cs
  85. 3
      com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta
  86. 26
      config/ppo/WalkerDynamicVariableSpeed.yaml
  87. 26
      config/ppo/WalkerStaticVariableSpeed.yaml
  88. 13
      ml-agents/mlagents/tf_utils/globals.py
  89. 157
      Project/Assets/ML-Agents/Examples/Walker/Prefabs/Platforms/PlatformWalkerDynamicSingleSpeed.prefab
  90. 7
      Project/Assets/ML-Agents/Examples/Walker/Prefabs/Platforms/PlatformWalkerDynamicSingleSpeed.prefab.meta
  91. 298
      Project/Assets/ML-Agents/Examples/Walker/Prefabs/Platforms/PlatformWalkerDynamicVariableSpeed.prefab
  92. 7
      Project/Assets/ML-Agents/Examples/Walker/Prefabs/Platforms/PlatformWalkerDynamicVariableSpeed.prefab.meta
  93. 287
      Project/Assets/ML-Agents/Examples/Walker/Prefabs/Ragdoll/WalkerRagdollDySingleSpeedVariant.prefab

20
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/DirectionIndicator.cs


using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine;
public bool updatedByAgent; //should this be updated by the agent? If not, it will use local settings
void OnEnable()
{
m_StartingYPos = transform.position.y;

{
transform.position = new Vector3(transformToFollow.position.x, m_StartingYPos + heightOffset, transformToFollow.position.z);
if (updatedByAgent)
return;
transform.position = new Vector3(transformToFollow.position.x, m_StartingYPos + heightOffset,
transformToFollow.position.z);
}
//Public method to allow an agent to directly update this component
public void MatchOrientation(Transform t)
{
transform.position = new Vector3(t.position.x, m_StartingYPos + heightOffset, t.position.z);
transform.rotation = t.rotation;
}
}
}

977
Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamic.unity
文件差异内容过多而无法显示
查看文件

2
Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamic.unity.meta


fileFormatVersion: 2
guid: 79d5d2687bfbe45f5b78bd6c04992e0d
guid: 65c87f50b8c81433d8fd7f6550773467
DefaultImporter:
externalObjects: {}
userData:

962
Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerStatic.unity
文件差异内容过多而无法显示
查看文件

161
Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs


using System;
using MLAgentsExamples;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgentsExamples;

public class WalkerAgent : Agent
{
public float maximumWalkingSpeed = 999; //The max walk velocity magnitude an agent will be rewarded for
Vector3 m_WalkDir; //Direction to the target
// Quaternion m_WalkDirLookRot; //Will hold the rotation to our target
[Header("Walk Speed")]
[Range(0.1f, 10)]
[SerializeField]
//The walking speed to try and achieve
private float m_TargetWalkingSpeed = 10;
public float MTargetWalkingSpeed // property
{
get { return m_TargetWalkingSpeed; }
set { m_TargetWalkingSpeed = Mathf.Clamp(value, .1f, m_maxWalkingSpeed); }
}
const float m_maxWalkingSpeed = 10; //The max walking speed
//Should the agent sample a new goal velocity each episode?
//If true, walkSpeed will be randomly set between zero and m_maxWalkingSpeed in OnEpisodeBegin()
//If false, the goal velocity will be walkingSpeed
public bool randomizeWalkSpeedEachEpisode;
//The direction an agent will walk during training.
private Vector3 m_WorldDirToWalk = Vector3.right;
[Header("Target To Walk Towards")] [Space(10)]
public TargetController target; //Target the agent will walk towards.
[Header("Target To Walk Towards")] public Transform target; //Target the agent will walk towards during training.
[Header("Body Parts")] [Space(10)] public Transform hips;
[Header("Body Parts")] public Transform hips;
public Transform chest;
public Transform spine;
public Transform head;

public Transform forearmR;
public Transform handR;
[Header("Orientation")] [Space(10)]
public OrientationCubeController orientationCube;
OrientationCubeController m_OrientationCube;
//The indicator graphic gameobject that points towards the target
DirectionIndicator m_DirectionIndicator;
orientationCube.UpdateOrientation(hips, target.transform);
m_OrientationCube = GetComponentInChildren<OrientationCubeController>();
m_DirectionIndicator = GetComponentInChildren<DirectionIndicator>();
//Setup each body part
m_JdController = GetComponent<JointDriveController>();

}
//Random start rotation to help generalize
transform.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);
hips.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);
UpdateOrientationObjects();
orientationCube.UpdateOrientation(hips, target.transform);
//Set our goal walking speed
MTargetWalkingSpeed =
randomizeWalkSpeedEachEpisode ? Random.Range(0.1f, m_maxWalkingSpeed) : MTargetWalkingSpeed;
SetResetParameters();
}

//Get velocities in the context of our orientation cube's space
//Note: You can get these velocities in world space as well but it may not train as well.
sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.velocity));
sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity));
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.velocity));
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity));
sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.position - hips.position));
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.position - hips.position));
if (bp.rb.transform != hips && bp.rb.transform != handL && bp.rb.transform != handR)
{

/// </summary>
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(Quaternion.FromToRotation(hips.forward, orientationCube.transform.forward));
sensor.AddObservation(Quaternion.FromToRotation(head.forward, orientationCube.transform.forward));
var cubeForward = m_OrientationCube.transform.forward;
sensor.AddObservation(orientationCube.transform.InverseTransformPoint(target.transform.position));
//velocity we want to match
var velGoal = cubeForward * MTargetWalkingSpeed;
//ragdoll's avg vel
var avgVel = GetAvgVelocity();
//current ragdoll velocity. normalized
sensor.AddObservation(Vector3.Distance(velGoal, avgVel));
//avg body vel relative to cube
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(avgVel));
//vel goal relative to cube
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(velGoal));
//rotation deltas
sensor.AddObservation(Quaternion.FromToRotation(hips.forward, cubeForward));
sensor.AddObservation(Quaternion.FromToRotation(head.forward, cubeForward));
//Position of target position relative to cube
sensor.AddObservation(m_OrientationCube.transform.InverseTransformPoint(target.transform.position));
foreach (var bodyPart in m_JdController.bodyPartsList)
{

bpDict[forearmR].SetJointStrength(vectorAction[++i]);
}
//Update OrientationCube and DirectionIndicator
void UpdateOrientationObjects()
{
m_WorldDirToWalk = target.position - hips.position;
m_OrientationCube.UpdateOrientation(hips, target);
if (m_DirectionIndicator)
{
m_DirectionIndicator.MatchOrientation(m_OrientationCube.transform);
}
}
var cubeForward = orientationCube.transform.forward;
orientationCube.UpdateOrientation(hips, target.transform);
UpdateOrientationObjects();
var cubeForward = m_OrientationCube.transform.forward;
// a. Velocity alignment with goal direction.
var moveTowardsTargetReward = Vector3.Dot(cubeForward,
Vector3.ClampMagnitude(m_JdController.bodyPartsDict[hips].rb.velocity, maximumWalkingSpeed));
if (float.IsNaN(moveTowardsTargetReward))
// a. Match target speed
//This reward will approach 1 if it matches perfectly and approach zero as it deviates
var matchSpeedReward = GetMatchingVelocityReward(cubeForward * MTargetWalkingSpeed, GetAvgVelocity());
//Check for NaNs
if (float.IsNaN(matchSpeedReward))
$" cubeForward: {cubeForward}\n"+
$" hips.velocity: {m_JdController.bodyPartsDict[hips].rb.velocity}\n"+
$" maximumWalkingSpeed: {maximumWalkingSpeed}"
$" cubeForward: {cubeForward}\n" +
$" hips.velocity: {m_JdController.bodyPartsDict[hips].rb.velocity}\n" +
$" maximumWalkingSpeed: {m_maxWalkingSpeed}"
// b. Rotation alignment with goal direction.
var lookAtTargetReward = Vector3.Dot(cubeForward, head.forward);
// b. Rotation alignment with target direction.
//This reward will approach 1 if it faces the target direction perfectly and approach zero as it deviates
var lookAtTargetReward = (Vector3.Dot(cubeForward, head.forward) + 1) * .5F;
//Check for NaNs
$" cubeForward: {cubeForward}\n"+
$" cubeForward: {cubeForward}\n" +
// c. Encourage head height. //Should normalize to ~1
var headHeightOverFeetReward =
((head.position.y - footL.position.y) + (head.position.y - footR.position.y) / 10);
if (float.IsNaN(headHeightOverFeetReward))
AddReward(matchSpeedReward * lookAtTargetReward);
}
//Returns the average velocity of all of the body parts
//Using the velocity of the hips only has shown to result in more erratic movement from the limbs, so...
//...using the average helps prevent this erratic movement
Vector3 GetAvgVelocity()
{
Vector3 velSum = Vector3.zero;
Vector3 avgVel = Vector3.zero;
//ALL RBS
int numOfRB = 0;
foreach (var item in m_JdController.bodyPartsList)
throw new ArgumentException(
"NaN in headHeightOverFeetReward.\n" +
$" head.position: {head.position}\n"+
$" footL.position: {footL.position}\n"+
$" footR.position: {footR.position}"
);
numOfRB++;
velSum += item.rb.velocity;
AddReward(
+ 0.02f * moveTowardsTargetReward
+ 0.02f * lookAtTargetReward
+ 0.005f * headHeightOverFeetReward
);
avgVel = velSum / numOfRB;
return avgVel;
}
//normalized value of the difference in avg speed vs goal walking speed.
public float GetMatchingVelocityReward(Vector3 velocityGoal, Vector3 actualVelocity)
{
//distance between our actual velocity and goal velocity
var velDeltaMagnitude = Mathf.Clamp(Vector3.Distance(actualVelocity, velocityGoal), 0, MTargetWalkingSpeed);
//return the value on a declining sigmoid shaped curve that decays from 1 to 0
//This reward will approach 1 if it matches perfectly and approach zero as it deviates
return Mathf.Pow(1 - Mathf.Pow(velDeltaMagnitude / MTargetWalkingSpeed, 2), 2);
}
/// <summary>

1001
Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamic.nn
文件差异内容过多而无法显示
查看文件

2
Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamic.nn.meta


fileFormatVersion: 2
guid: e785133c5b0ac461588106642550d1b3
guid: 8cbae6de45ea44d0c97366e252052722
ScriptedImporter:
fileIDToRecycleName:
11400000: main obj

1001
Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStatic.nn
文件差异内容过多而无法显示
查看文件

2
Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStatic.nn.meta


fileFormatVersion: 2
guid: 8dfd4337ed40e4d48872a4f86919c9da
guid: 185990f76b7804d1e83378e9d4454c6b
ScriptedImporter:
fileIDToRecycleName:
11400000: main obj

4
com.unity.ml-agents.extensions/Editor/Unity.ML-Agents.Extensions.Editor.asmdef


{
"name": "Unity.ML-Agents.Extensions.Editor",
"references": [
"Unity.ML-Agents.Extensions"
"Unity.ML-Agents.Extensions",
"Unity.ML-Agents",
"Unity.ML-Agents.Editor"
],
"includePlatforms": [
"Editor"

1
com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs


using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.Editor")]

29
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs


return new Pose { rotation = t.rotation, position = t.position };
}
/// <inheritdoc/>
protected internal override Object GetObjectAt(int index)
{
return m_Bodies[index];
}
internal IEnumerable<ArticulationBody> GetEnabledArticulationBodies()
{
if (m_Bodies == null)
{
yield break;
}
for (var i = 0; i < m_Bodies.Length; i++)
{
var articBody = m_Bodies[i];
if (articBody == null)
{
// Ignore a virtual root.
continue;
}
if (IsPoseEnabled(i))
{
yield return articBody;
}
}
}
}
}
#endif // UNITY_2020_1_OR_NEWER

8
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs


var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
// Start from i=1 to ignore the root
for (var i = 1; i < poseExtractor.Bodies.Length; i++)
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
numJointObservations += ArticulationBodyJointExtractor.NumObservations(
poseExtractor.Bodies[i], Settings
);
numJointObservations += ArticulationBodyJointExtractor.NumObservations(articBody, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}

54
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

string m_SensorName;
PoseExtractor m_PoseExtractor;
IJointExtractor[] m_JointExtractors;
List<IJointExtractor> m_JointExtractors;
/// Construct a new PhysicsBodySensor
/// Construct a new PhysicsBodySensor
/// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
/// <param name="poseExtractor"></param>
Rigidbody rootBody,
GameObject rootGameObject,
GameObject virtualRoot,
RigidBodyPoseExtractor poseExtractor,
string sensorName=null
string sensorName
var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject, virtualRoot);
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_SensorName = sensorName;
var rigidBodies = poseExtractor.Bodies;
if (rigidBodies != null)
{
m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root
for (var i = 1; i < rigidBodies.Length; i++)
{
var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
foreach(var rb in poseExtractor.GetEnabledRigidbodies())
m_JointExtractors = new IJointExtractor[0];
var jointExtractor = new RigidBodyJointExtractor(rb);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors.Add(jointExtractor);
}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);

m_Settings = settings;
var numJointExtractorObservations = 0;
var articBodies = poseExtractor.Bodies;
if (articBodies != null)
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root
for (var i = 1; i < articBodies.Length; i++)
{
var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
{
m_JointExtractors = new IJointExtractor[0];
var jointExtractor = new ArticulationBodyJointExtractor(articBody);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors.Add(jointExtractor);
}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);

117
com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs


using System;
using Object = UnityEngine.Object;
namespace Unity.MLAgents.Extensions.Sensors
{

{
if (m_ParentIndices == null)
{
return -1;
throw new NullReferenceException("No parent indices set");
}
return m_ParentIndices[index];

public void SetPoseEnabled(int index, bool val)
{
m_PoseEnabled[index] = val;
}
public bool IsPoseEnabled(int index)
{
return m_PoseEnabled[index];
}
/// <summary>

/// <returns></returns>
protected internal abstract Vector3 GetLinearVelocityAt(int index);
/// <summary>
/// Return the underlying object at the given index. This is only
/// used for display in the inspector.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
protected internal virtual Object GetObjectAt(int index)
{
return null;
}
/// <summary>
/// Update the internal model space transform storage based on the underlying system.

Debug.DrawLine(current.position+offset, current.position+offset+.1f*localRight, Color.blue);
}
}
/// <summary>
/// Simplified representation of the a node in the hierarchy for display.
/// </summary>
internal struct DisplayNode
{
/// <summary>
/// Underlying object in the hierarchy. Pass to EditorGUIUtility.ObjectContent() for display.
/// </summary>
public Object NodeObject;
/// <summary>
/// Whether the poses for the object are enabled.
/// </summary>
public bool Enabled;
/// <summary>
/// Depth in the hierarchy, used for adjusting the indent level.
/// </summary>
public int Depth;
/// <summary>
/// The index of the corresponding object in the PoseExtractor.
/// </summary>
public int OriginalIndex;
}
/// <summary>
/// Get a list of display nodes in depth-first order.
/// </summary>
/// <returns></returns>
internal IList<DisplayNode> GetDisplayNodes()
{
if (NumPoses == 0)
{
return Array.Empty<DisplayNode>();
}
var nodesOut = new List<DisplayNode>(NumPoses);
// List of children for each node
var tree = new Dictionary<int, List<int>>();
for (var i = 0; i < NumPoses; i++)
{
var parent = GetParentIndex(i);
if (i == -1)
{
continue;
}
if (!tree.ContainsKey(parent))
{
tree[parent] = new List<int>();
}
tree[parent].Add(i);
}
// Store (index, depth) in the stack
var stack = new Stack<(int, int)>();
stack.Push((0, 0));
while (stack.Count != 0)
{
var (current, depth) = stack.Pop();
var obj = GetObjectAt(current);
var node = new DisplayNode
{
NodeObject = obj,
Enabled = IsPoseEnabled(current),
OriginalIndex = current,
Depth = depth
};
nodesOut.Add(node);
// Add children
if (tree.ContainsKey(current))
{
// Push to the stack in reverse order
var children = tree[current];
for (var childIdx = children.Count-1; childIdx >= 0; childIdx--)
{
stack.Push((children[childIdx], depth+1));
}
}
// Safety check
// This shouldn't even happen, but in case we have a cycle in the graph
// exit instead of looping forever and eating up all the memory.
if (nodesOut.Count > NumPoses)
{
return nodesOut;
}
}
return nodesOut;
}
}
/// <summary>

80
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs


/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
/// separate from the actual Rigidbodies in the hierarchy. For locomotion tasks, with ragdolls, this provides
/// a stabilized refernece frame, which can improve learning.</param>
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null, GameObject virtualRoot = null)
/// a stabilized reference frame, which can improve learning.</param>
/// <param name="enableBodyPoses">Optional mapping of whether a body's psoe should be enabled or not.</param>
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null,
GameObject virtualRoot = null, Dictionary<Rigidbody, bool> enableBodyPoses = null)
{
if (rootBody == null)
{

Rigidbody[] rbs;
Joint[] joints;
joints = rootBody.GetComponentsInChildren <Joint>();
joints = rootGameObject.GetComponentsInChildren<Joint>();
}
if (rbs == null || rbs.Length == 0)

}
if (rbs[0] != rootBody)
if (rbs[0] != rootBody)
{
Debug.Log("Expected root body at index 0");
return;

}
}
var joints = rootBody.GetComponentsInChildren <Joint>();
foreach (var j in joints)
{
var parent = j.connectedBody;

// By default, ignore the root
SetPoseEnabled(0, false);
if (enableBodyPoses != null)
{
foreach (var pair in enableBodyPoses)
{
var rb = pair.Key;
if (bodyToIndex.TryGetValue(rb, out var index))
{
SetPoseEnabled(index, pair.Value);
}
}
}
}
/// <inheritdoc/>

return new Pose { rotation = body.rotation, position = body.position };
}
/// <inheritdoc/>
protected internal override Object GetObjectAt(int index)
{
if (index == 0 && m_VirtualRoot != null)
{
return m_VirtualRoot;
}
return m_Bodies[index];
}
/// <summary>
/// Get a dictionary indicating which Rigidbodies' poses are enabled or disabled.
/// </summary>
/// <returns></returns>
internal Dictionary<Rigidbody, bool> GetBodyPosesEnabled()
{
var bodyPosesEnabled = new Dictionary<Rigidbody, bool>(m_Bodies.Length);
for (var i = 0; i < m_Bodies.Length; i++)
{
var rb = m_Bodies[i];
if (rb == null)
{
continue; // skip virtual root
}
bodyPosesEnabled[rb] = IsPoseEnabled(i);
}
return bodyPosesEnabled;
}
internal IEnumerable<Rigidbody> GetEnabledRigidbodies()
{
if (m_Bodies == null)
{
yield break;
}
for (var i = 0; i < m_Bodies.Length; i++)
{
var rb = m_Bodies[i];
if (rb == null)
{
// Ignore a virtual root.
continue;
}
if (IsPoseEnabled(i))
{
yield return rb;
}
}
}
}
}

68
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs


using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

/// <summary>
/// Optional sensor name. This must be unique for each Agent.
/// </summary>
[SerializeField]
[SerializeField]
[HideInInspector]
RigidBodyPoseExtractor m_PoseExtractor;
/// <summary>
/// Creates a PhysicsBodySensor.
/// </summary>

return new PhysicsBodySensor(RootBody, gameObject, VirtualRoot, Settings, sensorName);
var _sensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{RootBody?.name}" : sensorName;
return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName);
}
/// <inheritdoc/>

return new[] { 0 };
}
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot);
var poseExtractor = GetPoseExtractor();
// Start from i=1 to ignore the root
for (var i = 1; i < poseExtractor.Bodies.Length; i++)
foreach(var rb in poseExtractor.GetEnabledRigidbodies())
var body = poseExtractor.Bodies[i];
var joint = body?.GetComponent<Joint>();
numJointObservations += RigidBodyJointExtractor.NumObservations(body, joint, Settings);
var joint = rb.GetComponent<Joint>();
numJointObservations += RigidBodyJointExtractor.NumObservations(rb, joint, Settings);
}
/// <summary>
/// Get the DisplayNodes of the hierarchy.
/// </summary>
/// <returns></returns>
internal IList<PoseExtractor.DisplayNode> GetDisplayNodes()
{
return GetPoseExtractor().GetDisplayNodes();
}
/// <summary>
/// Lazy construction of the PoseExtractor.
/// </summary>
/// <returns></returns>
RigidBodyPoseExtractor GetPoseExtractor()
{
if (m_PoseExtractor == null)
{
ResetPoseExtractor();
}
return m_PoseExtractor;
}
/// <summary>
/// Reset the pose extractor, trying to keep the enabled state of the corresponding poses the same.
/// </summary>
internal void ResetPoseExtractor()
{
// Get the current enabled state of each body, so that we can reinitialize with them.
Dictionary<Rigidbody, bool> bodyPosesEnabled = null;
if (m_PoseExtractor != null)
{
bodyPosesEnabled = m_PoseExtractor.GetBodyPosesEnabled();
}
m_PoseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot, bodyPosesEnabled);
}
/// <summary>
/// Toggle the pose at the given index.
/// </summary>
/// <param name="index"></param>
/// <param name="enabled"></param>
internal void SetPoseEnabled(int index, bool enabled)
{
GetPoseExtractor().SetPoseEnabled(index, enabled);
}
}

88
com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs


using System;
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;

public class PoseExtractorTests
{
class UselessPoseExtractor : PoseExtractor
class BasicPoseExtractor : PoseExtractor
{
protected internal override Pose GetPoseAt(int index)
{

protected internal override Vector3 GetLinearVelocityAt(int index)
protected internal override Vector3 GetLinearVelocityAt(int index)
}
class UselessPoseExtractor : BasicPoseExtractor
{
public void Init(int[] parentIndices)
{
Setup(parentIndices);

poseExtractor.UpdateModelSpacePoses();
Assert.AreEqual(0, poseExtractor.NumPoses);
// Iterating through poses and velocities should be an empty loop
foreach (var pose in poseExtractor.GetEnabledModelSpacePoses())
{
throw new UnityAgentsException("This shouldn't happen");
}
foreach (var pose in poseExtractor.GetEnabledLocalSpacePoses())
{
throw new UnityAgentsException("This shouldn't happen");
}
foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities())
{
throw new UnityAgentsException("This shouldn't happen");
}
foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities())
{
throw new UnityAgentsException("This shouldn't happen");
}
// Getting a parent index should throw an index exception
Assert.Throws <NullReferenceException>(
() => poseExtractor.GetParentIndex(0)
);
// DisplayNodes should be empty
var displayNodes = poseExtractor.GetDisplayNodes();
Assert.AreEqual(0, displayNodes.Count);
}
[Test]

Assert.AreEqual(size, localPoseIndex);
}
class BadPoseExtractor : PoseExtractor
[Test]
public void TestChainDisplayNodes()
{
var size = 4;
var chain = new ChainPoseExtractor(size);
var displayNodes = chain.GetDisplayNodes();
Assert.AreEqual(size, displayNodes.Count);
for (var i = 0; i < size; i++)
{
var displayNode = displayNodes[i];
Assert.AreEqual(i, displayNode.OriginalIndex);
Assert.AreEqual(null, displayNode.NodeObject);
Assert.AreEqual(i, displayNode.Depth);
Assert.AreEqual(true, displayNode.Enabled);
}
}
[Test]
public void TestDisplayNodesLoop()
{
// Degenerate case with a loop
var poseExtractor = new UselessPoseExtractor();
poseExtractor.Init(new[] {-1, 2, 1});
// This just shouldn't blow up
poseExtractor.GetDisplayNodes();
// Self-loop
poseExtractor.Init(new[] {-1, 1});
// This just shouldn't blow up
poseExtractor.GetDisplayNodes();
}
class BadPoseExtractor : BasicPoseExtractor
{
public BadPoseExtractor()
{

}
Setup(parents);
}
protected internal override Pose GetPoseAt(int index)
{
return Pose.identity;
}
protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
}
[Test]

var bad = new BadPoseExtractor();
});
}
}
public class PoseExtensionTests

67
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs


var rootRb = go.AddComponent<Rigidbody>();
var poseExtractor = new RigidBodyPoseExtractor(rootRb);
Assert.AreEqual(1, poseExtractor.NumPoses);
// Also pass the GameObject
poseExtractor = new RigidBodyPoseExtractor(rootRb, go);
Assert.AreEqual(1, poseExtractor.NumPoses);
}
[Test]
public void TestNoBodiesFound()
{
// Check that if we can't find any bodies under the game object, we get an empty extractor
var gameObj = new GameObject();
var rootRb = gameObj.AddComponent<Rigidbody>();
var otherGameObj = new GameObject();
var poseExtractor = new RigidBodyPoseExtractor(rootRb, otherGameObj);
Assert.AreEqual(0, poseExtractor.NumPoses);
// Add an RB under the other GameObject. Constructor will find a rigid body, but not the root.
var otherRb = otherGameObj.AddComponent<Rigidbody>();
poseExtractor = new RigidBodyPoseExtractor(rootRb, otherGameObj);
Assert.AreEqual(0, poseExtractor.NumPoses);
}
[Test]

Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(0).position);
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(0).rotation);
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(0));
// Check DisplayNodes gives expected results
var displayNodes = poseExtractor.GetDisplayNodes();
Assert.AreEqual(2, displayNodes.Count);
Assert.AreEqual(rb1, displayNodes[0].NodeObject);
Assert.AreEqual(false, displayNodes[0].Enabled);
Assert.AreEqual(rb2, displayNodes[1].NodeObject);
Assert.AreEqual(true, displayNodes[1].Enabled);
}
[Test]

Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(1).position);
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(1).rotation);
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(1));
}
[Test]
public void TestBodyPosesEnabledDictionary()
{
// * rootObj
// - rb1
// * go2
// - rb2
// - joint
var rootObj = new GameObject();
var rb1 = rootObj.AddComponent<Rigidbody>();
var go2 = new GameObject();
var rb2 = go2.AddComponent<Rigidbody>();
go2.transform.SetParent(rootObj.transform);
var joint = go2.AddComponent<ConfigurableJoint>();
joint.connectedBody = rb1;
var poseExtractor = new RigidBodyPoseExtractor(rb1);
// Expect the root body disabled and the attached one enabled.
Assert.IsFalse(poseExtractor.IsPoseEnabled(0));
Assert.IsTrue(poseExtractor.IsPoseEnabled(1));
var bodyPosesEnabled = poseExtractor.GetBodyPosesEnabled();
Assert.IsFalse(bodyPosesEnabled[rb1]);
Assert.IsTrue(bodyPosesEnabled[rb2]);
// Swap the values
bodyPosesEnabled[rb1] = true;
bodyPosesEnabled[rb2] = false;
var poseExtractor2 = new RigidBodyPoseExtractor(rb1, null, null, bodyPosesEnabled);
Assert.IsTrue(poseExtractor2.IsPoseEnabled(0));
Assert.IsFalse(poseExtractor2.IsPoseEnabled(1));
}
}
}

5
com.unity.ml-agents/CHANGELOG.md


#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
## [1.3.0-preview] 2020-08-12
## [1.3.0-preview] - 2020-08-12
### Major Changes
#### com.unity.ml-agents (C#)

Previously, this would result in an infinite loop and cause the editor to hang.
(#4226)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The algorithm used to normalize observations was introducing NaNs if the initial observations were too large
due to incorrect initialization. The initialization was fixed and is now the observation means from the
first trajectory processed. (#4299)
## [1.2.0-preview] - 2020-07-15

33
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs


/// the offset into the original array, and an length.
/// </summary>
/// <typeparam name="T">The type of object stored in the underlying <see cref="Array"/></typeparam>
internal readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
public readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
where T : struct
{
/// <summary>

/// </summary>
public static ActionSegment<T> Empty = new ActionSegment<T>(System.Array.Empty<T>(), 0, 0);
static void CheckParameters(T[] actionArray, int offset, int length)
static void CheckParameters(IReadOnlyCollection<T> actionArray, int offset, int length)
if (offset + length > actionArray.Length)
if (offset + length > actionArray.Count)
$"are out of bounds of actionArray: {actionArray.Length}.");
$"are out of bounds of actionArray: {actionArray.Count}.");
/// Construct an <see cref="ActionSegment{T}"/> with just an actionArray. The <see cref="Offset"/> will
/// be set to 0 and the <see cref="Length"/> will be set to `actionArray.Length`.
/// </summary>
/// <param name="actionArray">The action array to use for the this segment.</param>
public ActionSegment(T[] actionArray) : this(actionArray, 0, actionArray.Length) { }
/// <summary>
/// Construct an <see cref="ActionSegment{T}"/> with an underlying array
/// and offset, and a length.
/// </summary>

public ActionSegment(T[] actionArray, int offset, int length)
{
#if DEBUG
#endif
Array = actionArray;
Offset = offset;
Length = length;

}
return Array[Offset + index];
}
set
{
if (index < 0 || index > Length)
{
throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}");
}
Array[Offset + index] = value;
}
}
/// <summary>
/// Sets the segment of the backing array to all zeros.
/// </summary>
public void Clear()
{
System.Array.Clear(Array, Offset, Length);
}
/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>

2
com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs


/// <summary>
/// Defines the structure of an Action Space to be used by the Actuator system.
/// </summary>
internal readonly struct ActionSpec
public readonly struct ActionSpec
{
/// <summary>

51
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs


/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// </summary>
public float[] StoredContinuousActions { get; private set; }
// public float[] StoredContinuousActions { get; private set; }
public int[] StoredDiscreteActions { get; private set; }
// public int[] StoredDiscreteActions { get; private set; }
public ActionBuffers StoredActions { get; private set; }
/// <summary>
/// Create an ActuatorList with a preset capacity.

// Sort the Actuators by name to ensure determinism
SortActuators();
StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions];
StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches];
var continuousActions = numContinuousActions == 0 ? ActionSegment<float>.Empty :
new ActionSegment<float>(new float[numContinuousActions]);
var discreteActions = numDiscreteBranches == 0 ? ActionSegment<int>.Empty : new ActionSegment<int>(new int[numDiscreteBranches]);
StoredActions = new ActionBuffers(continuousActions, discreteActions);
m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
m_ReadyForExecution = true;
}

/// continuous actions for the IActuators in this list.</param>
/// <param name="discreteActionBuffer">The action buffer which contains all of the
/// discrete actions for the IActuators in this list.</param>
public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer)
public void UpdateActions(ActionBuffers actions)
UpdateActionArray(continuousActionBuffer, StoredContinuousActions);
UpdateActionArray(discreteActionBuffer, StoredDiscreteActions);
UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions);
UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions);
static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
static void UpdateActionArray<T>(ActionSegment<T> sourceActionBuffer, ActionSegment<T> destination)
where T : struct
if (sourceActionBuffer == null || sourceActionBuffer.Length == 0)
if (sourceActionBuffer.Length <= 0)
Array.Clear(destination, 0, destination.Length);
destination.Clear();
}
else
{

Array.Copy(sourceActionBuffer, destination, destination.Length);
Array.Copy(sourceActionBuffer.Array,
sourceActionBuffer.Offset,
destination.Array,
destination.Offset,
destination.Length);
}
}

for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
m_DiscreteActionMask.CurrentBranchOffset = offset;
actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
offset += actuator.ActionSpec.NumDiscreteActions;
if (actuator.ActionSpec.NumDiscreteActions > 0)
{
m_DiscreteActionMask.CurrentBranchOffset = offset;
actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
offset += actuator.ActionSpec.NumDiscreteActions;
}
}
}

var continuousActions = ActionSegment<float>.Empty;
if (numContinuousActions > 0)
{
continuousActions = new ActionSegment<float>(StoredContinuousActions,
continuousActions = new ActionSegment<float>(StoredActions.ContinuousActions.Array,
continuousStart,
numContinuousActions);
}

{
discreteActions = new ActionSegment<int>(StoredDiscreteActions,
discreteActions = new ActionSegment<int>(StoredActions.DiscreteActions.Array,
discreteStart,
numDiscreteActions);
}

}
/// <summary>
/// Resets the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers to be all
/// Resets the <see cref="ActionBuffers"/> to be all
/// zeros and calls <see cref="IActuator.ResetData"/> on each <see cref="IActuator"/> managed by this object.
/// </summary>
public void ResetData()

return;
}
Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);
StoredActions.Clear();
m_DiscreteActionMask.ResetMask();
}

72
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs


using System;
using System.Linq;
using UnityEngine;
namespace Unity.MLAgents.Actuators
{

/// </summary>
internal readonly struct ActionBuffers
public readonly struct ActionBuffers
{
/// <summary>
/// An empty action buffer.

public ActionSegment<int> DiscreteActions { get; }
/// <summary>
/// Create an <see cref="ActionBuffers"/> instance with discrete actions stored as a float array. This exists
/// to achieve backward compatibility with the former Agent methods which used a float array for both continuous
/// and discrete actions.
/// </summary>
/// <param name="discreteActions">The float array of discrete actions.</param>
/// <returns>An <see cref="ActionBuffers"/> instance initialized with a <see cref="DiscreteActions"/>
/// <see cref="ActionSegment{T}"/> initialized from a float array.</returns>
public static ActionBuffers FromDiscreteActions(float[] discreteActions)
{
return new ActionBuffers(ActionSegment<float>.Empty, discreteActions == null ? ActionSegment<int>.Empty
: new ActionSegment<int>(Array.ConvertAll(discreteActions,
x => (int)x)));
}
public ActionBuffers(float[] continuousActions, int[] discreteActions)
: this(new ActionSegment<float>(continuousActions), new ActionSegment<int>(discreteActions)) { }
/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance with the continuous and discrete actions that will
/// be used.
/// </summary>

DiscreteActions = discreteActions;
}
/// <summary>
/// Clear the <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/> segments to be all zeros.
/// </summary>
public void Clear()
{
ContinuousActions.Clear();
DiscreteActions.Clear();
}
/// <inheritdoc cref="ValueType.Equals(object)"/>
public override bool Equals(object obj)
{

return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
}
}
/// <summary>
/// Packs the continuous and discrete actions into one float array. The array passed into this method
/// must have a Length that is greater than or equal to the sum of the Lengths of
/// <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/>.
/// </summary>
/// <param name="destination">A float array to pack actions into whose length is greater than or
/// equal to the addition of the Lengths of this objects <see cref="ContinuousActions"/> and
/// <see cref="DiscreteActions"/> segments.</param>
public void PackActions(in float[] destination)
{
Debug.Assert(destination.Length >= ContinuousActions.Length + DiscreteActions.Length,
$"argument '{nameof(destination)}' is not large enough to pack the actions into.\n" +
$"{nameof(destination)}.Length: {destination.Length}\n" +
$"{nameof(ContinuousActions)}.Length + {nameof(DiscreteActions)}.Length: {ContinuousActions.Length + DiscreteActions.Length}");
var start = 0;
if (ContinuousActions.Length > 0)
{
Array.Copy(ContinuousActions.Array,
ContinuousActions.Offset,
destination,
start,
ContinuousActions.Length);
start = ContinuousActions.Length;
}
if (start >= destination.Length)
{
return;
}
if (DiscreteActions.Length > 0)
{
Array.Copy(DiscreteActions.Array,
DiscreteActions.Offset,
destination,
start,
DiscreteActions.Length);
}
}
internal interface IActionReceiver
public interface IActionReceiver
{
/// <summary>

2
com.unity.ml-agents/Runtime/Actuators/IActuator.cs


/// <summary>
/// Abstraction that facilitates the execution of actions.
/// </summary>
internal interface IActuator : IActionReceiver
public interface IActuator : IActionReceiver
{
int TotalNumberOfActions { get; }

2
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs


/// <summary>
/// Interface for writing a mask to disable discrete actions for agents for the next decision.
/// </summary>
internal interface IDiscreteActionMask
public interface IDiscreteActionMask
{
/// <summary>
/// Modifies an action mask for discrete control agents.

2
com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs


namespace Unity.MLAgents.Actuators
{
internal class VectorActuator : IActuator
public class VectorActuator : IActuator
{
IActionReceiver m_ActionReceiver;

201
com.unity.ml-agents/Runtime/Agent.cs


using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Demonstrations;

/// to separate between different agents in the environment.
/// </summary>
public int episodeId;
}
/// <summary>
/// Struct that contains the action information sent from the Brain to the
/// Agent.
/// </summary>
internal struct AgentAction
{
public float[] vectorActions;
public void ClearActions()
{
Array.Clear(storedVectorActions, 0, storedVectorActions.Length);
}
public void CopyActions(ActionBuffers actionBuffers)
{
actionBuffers.PackActions(storedVectorActions);
}
}
/// <summary>

/// can only take an action when it touches the ground, so several frames might elapse between
/// one decision and the need for the next.
///
/// Use the <see cref="OnActionReceived"/> function to implement the actions your agent can take,
/// Use the <see cref="OnActionReceived(float[])"/> function to implement the actions your agent can take,
/// such as moving to reach a goal or interacting with its environment.
///
/// When you call <see cref="EndEpisode"/> on an agent or the agent reaches its <see cref="MaxStep"/> count,

"docs/Learning-Environment-Design-Agents.md")]
[Serializable]
[RequireComponent(typeof(BehaviorParameters))]
public class Agent : MonoBehaviour, ISerializationCallbackReceiver
public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver
{
IPolicy m_Brain;
BehaviorParameters m_PolicyFactory;

/// Current Agent information (message sent to Brain).
AgentInfo m_Info;
/// Current Agent action (message sent from Brain).
AgentAction m_Action;
/// Represents the reward the agent accumulated during the current step.
/// It is reset to 0 at the beginning of every step.

internal VectorSensor collectObservationsSensor;