
Merge remote-tracking branch 'origin/master' into release_6-to-master

/release_6_branch
Christopher Goy, 4 years ago
Current commit
5a233353
201 files changed, with 2614 additions and 4261 deletions
  1. .circleci/config.yml (2 changes)
  2. DevProject/Assets/ML-Agents/Scripts/Tests/Performance/SensorPerformanceTests.cs (9 changes)
  3. Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs (16 changes)
  4. Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs (9 changes)
  5. Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicController.cs (9 changes)
  6. Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs (30 changes)
  7. Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs (37 changes)
  8. Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs (29 changes)
  9. Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs (29 changes)
  10. Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs (23 changes)
  11. Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs (23 changes)
  12. Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs (23 changes)
  13. Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs (12 changes)
  14. Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/DirectionIndicator.cs (20 changes)
  15. Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs (31 changes)
  16. Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs (4 changes)
  17. Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs (20 changes)
  18. Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamic.unity (977 changes)
  19. Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamic.unity.meta (2 changes)
  20. Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerStatic.unity (962 changes)
  21. Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs (218 changes)
  22. Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamic.nn (1001 changes)
  23. Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamic.nn.meta (2 changes)
  24. Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStatic.nn (1001 changes)
  25. Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStatic.nn.meta (2 changes)
  26. Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs (31 changes)
  27. Project/Assets/ML-Agents/Examples/Worm/Scripts/WormAgent.cs (17 changes)
  28. Project/ProjectSettings/ProjectVersion.txt (2 changes)
  29. com.unity.ml-agents.extensions/Editor/Unity.ML-Agents.Extensions.Editor.asmdef (4 changes)
  30. com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs (1 change)
  31. com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs (29 changes)
  32. com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs (8 changes)
  33. com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs (54 changes)
  34. com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs (117 changes)
  35. com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs (80 changes)
  36. com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs (68 changes)
  37. com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs (88 changes)
  38. com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs (67 changes)
  39. com.unity.ml-agents/CHANGELOG.md (29 changes)
  40. com.unity.ml-agents/Runtime/Agent.cs (201 changes)
  41. com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs (14 changes)
  42. com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (2 changes)
  43. com.unity.ml-agents/Runtime/DecisionRequester.cs (6 changes)
  44. com.unity.ml-agents/Runtime/DiscreteActionMasker.cs (118 changes)
  45. com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs (17 changes)
  46. com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs (37 changes)
  47. com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs (20 changes)
  48. com.unity.ml-agents/Runtime/Policies/IPolicy.cs (3 changes)
  49. com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs (15 changes)
  50. com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs (3 changes)
  51. com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs (48 changes)
  52. com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs (7 changes)
  53. config/ppo/WalkerDynamic.yaml (2 changes)
  54. config/ppo/WalkerStatic.yaml (2 changes)
  55. docs/Learning-Environment-Create-New.md (51 changes)
  56. docs/Learning-Environment-Examples.md (31 changes)
  57. gym-unity/gym_unity/__init__.py (2 changes)
  58. ml-agents-envs/mlagents_envs/__init__.py (2 changes)
  59. ml-agents-envs/mlagents_envs/exception.py (8 changes)
  60. ml-agents/mlagents/trainers/__init__.py (2 changes)
  61. ml-agents/mlagents/trainers/ghost/trainer.py (2 changes)
  62. ml-agents/mlagents/trainers/optimizer/tf_optimizer.py (3 changes)
  63. ml-agents/mlagents/trainers/policy/policy.py (15 changes)
  64. ml-agents/mlagents/trainers/policy/tf_policy.py (115 changes)
  65. ml-agents/mlagents/trainers/ppo/optimizer.py (10 changes)
  66. ml-agents/mlagents/trainers/ppo/trainer.py (17 changes)
  67. ml-agents/mlagents/trainers/sac/optimizer.py (2 changes)
  68. ml-agents/mlagents/trainers/sac/trainer.py (17 changes)
  69. ml-agents/mlagents/trainers/settings.py (8 changes)
  70. ml-agents/mlagents/trainers/stats.py (41 changes)
  71. ml-agents/mlagents/trainers/tests/test_barracuda_converter.py (27 changes)
  72. ml-agents/mlagents/trainers/tests/test_bcmodule.py (10 changes)
  73. ml-agents/mlagents/trainers/tests/test_env_param_manager.py (62 changes)
  74. ml-agents/mlagents/trainers/tests/test_nn_policy.py (62 changes)
  75. ml-agents/mlagents/trainers/tests/test_ppo.py (8 changes)
  76. ml-agents/mlagents/trainers/tests/test_reward_signals.py (1 change)
  77. ml-agents/mlagents/trainers/tests/test_rl_trainer.py (21 changes)
  78. ml-agents/mlagents/trainers/tests/test_sac.py (6 changes)
  79. ml-agents/mlagents/trainers/tests/test_simple_rl.py (8 changes)
  80. ml-agents/mlagents/trainers/tests/test_tf_policy.py (20 changes)
  81. ml-agents/mlagents/trainers/trainer/rl_trainer.py (26 changes)
  82. ml-agents/mlagents/trainers/trainer/trainer.py (2 changes)
  83. ml-agents/mlagents/trainers/trainer_controller.py (10 changes)
  84. test_requirements.txt (3 changes)
  85. utils/validate_release_links.py (2 changes)
  86. Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Targets/DynamicTarget.prefab (21 changes)
  87. Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Targets/StaticTarget.prefab (19 changes)
  88. Project/Assets/ML-Agents/Examples/Walker/Prefabs/Ragdoll/WalkerRagdollBase.prefab (82 changes)
  89. Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/PlatformDynamicTarget.prefab (523 changes)
  90. Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/PlatformDynamicTarget.prefab.meta (7 changes)
  91. Project/Assets/ML-Agents/Examples/SharedAssets/Prefabs/Targets.meta (8 changes)
  92. Project/Assets/ML-Agents/Examples/Walker/Demos/ExpertWalkerDy.demo.meta (10 changes)
  93. Project/Assets/ML-Agents/Examples/Walker/Demos/ExpertWalkerDyVS.demo.meta (10 changes)
  94. Project/Assets/ML-Agents/Examples/Walker/Demos/ExpertWalkerStVS.demo.meta (10 changes)
  95. Project/Assets/ML-Agents/Examples/Walker/Demos/ExpertWalkerSta.demo.meta (10 changes)

2
.circleci/config.yml


. venv/bin/activate
mkdir test-reports
pip freeze > test-reports/pip_versions.txt
pytest -n 2 --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=test-reports/junit.xml -p no:warnings
pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=test-reports/junit.xml -p no:warnings
- run:
name: Verify there are no hidden/missing metafiles.

9
DevProject/Assets/ML-Agents/Scripts/Tests/Performance/SensorPerformanceTests.cs


using NUnit.Framework;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;

{
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
{
}
}

sensor.AddObservation(new Quaternion(1, 2, 3, 4));
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
{
}
}

[Observable]
public Quaternion QuaternionField = new Quaternion(1, 2, 3, 4);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
{
}
}

get { return m_QuaternionField; }
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
{
}
}

16
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs


using System;
using Unity.MLAgents.Actuators;
using Random = UnityEngine.Random;
public class Ball3DAgent : Agent
{

sensor.AddObservation(m_BallRb.velocity);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
var actionZ = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);
if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
(gameObject.transform.rotation.z > -0.25f && actionZ < 0f))

SetResetParameters();
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = -Input.GetAxis("Horizontal");
actionsOut[1] = Input.GetAxis("Vertical");
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = -Input.GetAxis("Horizontal");
continuousActionsOut[1] = Input.GetAxis("Vertical");
}
public void SetBall()
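
Most of the example-script hunks in this commit follow the pattern visible in the Ball3DAgent change above: the float[] overloads of OnActionReceived and Heuristic are replaced by ActionBuffers-based overloads, and continuous values are read and written through the ContinuousActions segment. A minimal sketch of the new callback shape, using a hypothetical MinimalBalanceAgent class (not part of this commit) purely for illustration:

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;

public class MinimalBalanceAgent : Agent // hypothetical example, not from the diff
{
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Continuous actions now come from the ContinuousActions segment
        // instead of a raw float[] parameter.
        var tiltZ = Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
        var tiltX = Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);
        transform.Rotate(new Vector3(1, 0, 0), tiltX);
        transform.Rotate(new Vector3(0, 0, 1), tiltZ);
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // Heuristic writes into the ContinuousActions segment rather than
        // indexing the output array directly.
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = -Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
    }
}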

9
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class Ball3DHardAgent : Agent

sensor.AddObservation((ball.transform.position - gameObject.transform.position));
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
var continuousActions = actionBuffers.ContinuousActions;
var actionZ = 2f * Mathf.Clamp(continuousActions[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(continuousActions[1], -1f, 1f);
if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
(gameObject.transform.rotation.z > -0.25f && actionZ < 0f))

9
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicController.cs


using UnityEngine;
using UnityEngine.SceneManagement;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using UnityEngine.Serialization;
/// <summary>

/// Controls the movement of the GameObject based on the actions received.
/// </summary>
/// <param name="vectorAction"></param>
public void ApplyAction(float[] vectorAction)
public void ApplyAction(ActionSegment<int> vectorAction)
var movement = (int)vectorAction[0];
var movement = vectorAction[0];
var direction = 0;

if (Academy.Instance.IsCommunicatorOn)
{
// Apply the previous step's actions
ApplyAction(m_Agent.GetAction());
ApplyAction(m_Agent.GetStoredActionBuffers().DiscreteActions);
m_Agent?.RequestDecision();
}
else

// Apply the previous step's actions
ApplyAction(m_Agent.GetAction());
ApplyAction(m_Agent.GetStoredActionBuffers().DiscreteActions);
m_TimeSinceDecision = 0f;
m_Agent?.RequestDecision();

30
Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class BouncerAgent : Agent

sensor.AddObservation(target.transform.localPosition);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
for (var i = 0; i < vectorAction.Length; i++)
var continuousActions = actionBuffers.ContinuousActions;
for (var i = 0; i < continuousActions.Length; i++)
vectorAction[i] = Mathf.Clamp(vectorAction[i], -1f, 1f);
continuousActions[i] = Mathf.Clamp(continuousActions[i], -1f, 1f);
var x = vectorAction[0];
var y = ScaleAction(vectorAction[1], 0, 1);
var z = vectorAction[2];
var x = continuousActions[0];
var y = ScaleAction(continuousActions[1], 0, 1);
var z = continuousActions[2];
vectorAction[0] * vectorAction[0] +
vectorAction[1] * vectorAction[1] +
vectorAction[2] * vectorAction[2]) / 3f);
continuousActions[0] * continuousActions[0] +
continuousActions[1] * continuousActions[1] +
continuousActions[2] * continuousActions[2]) / 3f);
m_LookDir = new Vector3(x, y, z);
}

}
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = Input.GetAxis("Horizontal");
actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
actionsOut[2] = Input.GetAxis("Vertical");
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = Input.GetAxis("Horizontal");
continuousActionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
continuousActionsOut[2] = Input.GetAxis("Vertical");
}
void Update()

37
Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs


using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

AddReward(1f);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var continuousActions = actionBuffers.ContinuousActions;
bpDict[leg0Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg1Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg2Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg3Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg0Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg1Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg2Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg3Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg0Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[leg1Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[leg2Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[leg3Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[leg0Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[leg1Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[leg2Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[leg3Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[leg0Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg1Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg2Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg3Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg0Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg1Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg2Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg3Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg0Upper].SetJointStrength(continuousActions[++i]);
bpDict[leg1Upper].SetJointStrength(continuousActions[++i]);
bpDict[leg2Upper].SetJointStrength(continuousActions[++i]);
bpDict[leg3Upper].SetJointStrength(continuousActions[++i]);
bpDict[leg0Lower].SetJointStrength(continuousActions[++i]);
bpDict[leg1Lower].SetJointStrength(continuousActions[++i]);
bpDict[leg2Lower].SetJointStrength(continuousActions[++i]);
bpDict[leg3Lower].SetJointStrength(continuousActions[++i]);
}
void FixedUpdate()

29
Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs


using System;
using Unity.MLAgents.Actuators;
using Random = UnityEngine.Random;
public class FoodCollectorAgent : Agent
{

return new Color32(r, g, b, 255);
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
{
m_Shoot = false;

gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = 0f;
actionsOut[1] = 0f;
actionsOut[2] = 0f;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
discreteActionsOut[1] = 0;
discreteActionsOut[2] = 0;
actionsOut[2] = 2f;
discreteActionsOut[2] = 2;
actionsOut[0] = 1f;
discreteActionsOut[0] = 1;
actionsOut[2] = 1f;
discreteActionsOut[2] = 1;
actionsOut[0] = 2f;
discreteActionsOut[0] = 2;
actionsOut[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
discreteActionsOut[3] = Input.GetKey(KeyCode.Space) ? 1 : 0;
}
public override void OnEpisodeBegin()

29
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using UnityEngine.Serialization;
public class GridAgent : Agent

m_ResetParams = Academy.Instance.EnvironmentParameters;
}
public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
// Mask the necessary actions if selected by the user.
if (maskActions)

if (positionX == 0)
{
actionMasker.SetMask(0, new []{ k_Left});
actionMask.WriteMask(0, new []{ k_Left});
actionMasker.SetMask(0, new []{k_Right});
actionMask.WriteMask(0, new []{k_Right});
actionMasker.SetMask(0, new []{k_Down});
actionMask.WriteMask(0, new []{k_Down});
actionMasker.SetMask(0, new []{k_Up});
actionMask.WriteMask(0, new []{k_Up});
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var action = Mathf.FloorToInt(vectorAction[0]);
var action = actionBuffers.DiscreteActions[0];
var targetPos = transform.position;
switch (action)

}
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = k_NoAction;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = k_NoAction;
actionsOut[0] = k_Right;
discreteActionsOut[0] = k_Right;
actionsOut[0] = k_Up;
discreteActionsOut[0] = k_Up;
actionsOut[0] = k_Left;
discreteActionsOut[0] = k_Left;
actionsOut[0] = k_Down;
discreteActionsOut[0] = k_Down;
}
}
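
The GridAgent hunk above also shows the action-masking half of the API change: CollectDiscreteActionMasks(DiscreteActionMasker) becomes WriteDiscreteActionMask(IDiscreteActionMask), and SetMask becomes WriteMask. A small self-contained sketch of the new override, with illustrative field and constant values rather than GridAgent's real ones:

using Unity.MLAgents;
using Unity.MLAgents.Actuators;

public class EdgeMaskingAgent : Agent // hypothetical example, not from the diff
{
    const int k_Left = 3;          // illustrative action index on branch 0
    public bool maskActions = true;
    int positionX;                 // grid column; 0 means the left edge

    public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        if (maskActions && positionX == 0)
        {
            // Disallow the "move left" action on branch 0 at the left edge;
            // GridAgent handles the right/up/down edges the same way.
            actionMask.WriteMask(0, new[] { k_Left });
        }
    }
}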

23
Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs


using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class HallwayAgent : Agent

m_GroundRenderer.material = m_GroundMaterial;
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
var action = Mathf.FloorToInt(act[0]);
var action = act[0];
switch (action)
{
case 1:

m_AgentRb.AddForce(dirToGo * m_HallwaySettings.agentRunSpeed, ForceMode.VelocityChange);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
}
void OnCollisionEnter(Collision col)

}
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = 0;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
actionsOut[0] = 3;
discreteActionsOut[0] = 3;
actionsOut[0] = 1;
discreteActionsOut[0] = 1;
actionsOut[0] = 4;
discreteActionsOut[0] = 4;
actionsOut[0] = 2;
discreteActionsOut[0] = 2;
}
}

23
Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs


using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
public class PushAgentBasic : Agent
{

/// <summary>
/// Moves the agent according to the selected action.
/// </summary>
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
var action = Mathf.FloorToInt(act[0]);
var action = act[0];
switch (action)
{

/// <summary>
/// Called every step of the engine. Here the agent takes an action.
/// </summary>
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = 0;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
actionsOut[0] = 3;
discreteActionsOut[0] = 3;
actionsOut[0] = 1;
discreteActionsOut[0] = 1;
actionsOut[0] = 4;
discreteActionsOut[0] = 4;
actionsOut[0] = 2;
discreteActionsOut[0] = 2;
}
}

23
Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs


using UnityEngine;
using Random = UnityEngine.Random;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class PyramidAgent : Agent

}
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
var action = Mathf.FloorToInt(act[0]);
var action = act[0];
switch (action)
{
case 1:

m_AgentRb.AddForce(dirToGo * 2f, ForceMode.VelocityChange);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = 0;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
actionsOut[0] = 3;
discreteActionsOut[0] = 3;
actionsOut[0] = 1;
discreteActionsOut[0] = 1;
actionsOut[0] = 4;
discreteActionsOut[0] = 4;
actionsOut[0] = 2;
discreteActionsOut[0] = 2;
}
}

12
Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class ReacherAgent : Agent

/// <summary>
/// The agent's four actions correspond to torques on each of the two joints.
/// </summary>
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var torqueX = Mathf.Clamp(vectorAction[0], -1f, 1f) * 150f;
var torqueZ = Mathf.Clamp(vectorAction[1], -1f, 1f) * 150f;
var torqueX = Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f) * 150f;
var torqueZ = Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f) * 150f;
torqueX = Mathf.Clamp(vectorAction[2], -1f, 1f) * 150f;
torqueZ = Mathf.Clamp(vectorAction[3], -1f, 1f) * 150f;
torqueX = Mathf.Clamp(actionBuffers.ContinuousActions[2], -1f, 1f) * 150f;
torqueZ = Mathf.Clamp(actionBuffers.ContinuousActions[3], -1f, 1f) * 150f;
m_RbB.AddTorque(new Vector3(torqueX, 0f, torqueZ));
}

20
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/DirectionIndicator.cs


using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine;
public bool updatedByAgent; //should this be updated by the agent? If not, it will use local settings
void OnEnable()
{
m_StartingYPos = transform.position.y;

{
transform.position = new Vector3(transformToFollow.position.x, m_StartingYPos + heightOffset, transformToFollow.position.z);
if (updatedByAgent)
return;
transform.position = new Vector3(transformToFollow.position.x, m_StartingYPos + heightOffset,
transformToFollow.position.z);
}
//Public method to allow an agent to directly update this component
public void MatchOrientation(Transform t)
{
transform.position = new Vector3(t.position.x, m_StartingYPos + heightOffset, t.position.z);
transform.rotation = t.rotation;
}
}
}

31
Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs


using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
public class AgentSoccer : Agent

m_ResetParams = Academy.Instance.EnvironmentParameters;
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
{
var dirToGo = Vector3.zero;
var rotateDir = Vector3.zero;

var forwardAxis = (int)act[0];
var rightAxis = (int)act[1];
var rotateAxis = (int)act[2];
var forwardAxis = act[0];
var rightAxis = act[1];
var rotateAxis = act[2];
switch (forwardAxis)
{

ForceMode.VelocityChange);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
{
if (position == Position.Goalie)

// Existential penalty cumulant for Generic
timePenalty -= m_Existential;
}
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
Array.Clear(actionsOut, 0, actionsOut.Length);
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut.Clear();
actionsOut[0] = 1f;
discreteActionsOut[0] = 1;
actionsOut[0] = 2f;
discreteActionsOut[0] = 2;
actionsOut[2] = 1f;
discreteActionsOut[2] = 1;
actionsOut[2] = 2f;
discreteActionsOut[2] = 2;
actionsOut[1] = 1f;
discreteActionsOut[1] = 1;
actionsOut[1] = 2f;
discreteActionsOut[1] = 2;
}
}
/// <summary>

4
Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class TemplateAgent : Agent

}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
{
}

20
Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs


using UnityEngine;
using UnityEngine.UI;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class TennisAgent : Agent

sensor.AddObservation(m_InvertMult * gameObject.transform.rotation.z);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var moveX = Mathf.Clamp(vectorAction[0], -1f, 1f) * m_InvertMult;
var moveY = Mathf.Clamp(vectorAction[1], -1f, 1f);
var rotate = Mathf.Clamp(vectorAction[2], -1f, 1f) * m_InvertMult;
var continuousActions = actionBuffers.ContinuousActions;
var moveX = Mathf.Clamp(continuousActions[0], -1f, 1f) * m_InvertMult;
var moveY = Mathf.Clamp(continuousActions[1], -1f, 1f);
var rotate = Mathf.Clamp(continuousActions[2], -1f, 1f) * m_InvertMult;
if (moveY > 0.5 && transform.position.y - transform.parent.transform.position.y < -1.5f)
{

m_TextComponent.text = score.ToString();
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = Input.GetAxis("Horizontal"); // Racket Movement
actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
actionsOut[2] = Input.GetAxis("Vertical"); // Racket Rotation
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = Input.GetAxis("Horizontal"); // Racket Movement
continuousActionsOut[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
continuousActionsOut[2] = Input.GetAxis("Vertical"); // Racket Rotation
}
public override void OnEpisodeBegin()

977
Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamic.unity
File diff too large to display.

2
Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerDynamic.unity.meta


fileFormatVersion: 2
guid: 79d5d2687bfbe45f5b78bd6c04992e0d
guid: 65c87f50b8c81433d8fd7f6550773467
DefaultImporter:
externalObjects: {}
userData:

962
Project/Assets/ML-Agents/Examples/Walker/Scenes/WalkerStatic.unity
File diff too large to display.

218
Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs


using System;
using MLAgentsExamples;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;
using BodyPart = Unity.MLAgentsExamples.BodyPart;

{
public float maximumWalkingSpeed = 999; //The max walk velocity magnitude an agent will be rewarded for
Vector3 m_WalkDir; //Direction to the target
// Quaternion m_WalkDirLookRot; //Will hold the rotation to our target
[Header("Walk Speed")]
[Range(0.1f, 10)]
[SerializeField]
//The walking speed to try and achieve
private float m_TargetWalkingSpeed = 10;
[Header("Target To Walk Towards")] [Space(10)]
public TargetController target; //Target the agent will walk towards.
public float MTargetWalkingSpeed // property
{
get { return m_TargetWalkingSpeed; }
set { m_TargetWalkingSpeed = Mathf.Clamp(value, .1f, m_maxWalkingSpeed); }
}
const float m_maxWalkingSpeed = 10; //The max walking speed
//Should the agent sample a new goal velocity each episode?
//If true, walkSpeed will be randomly set between zero and m_maxWalkingSpeed in OnEpisodeBegin()
//If false, the goal velocity will be walkingSpeed
public bool randomizeWalkSpeedEachEpisode;
//The direction an agent will walk during training.
private Vector3 m_WorldDirToWalk = Vector3.right;
[Header("Target To Walk Towards")] public Transform target; //Target the agent will walk towards during training.
[Header("Body Parts")] [Space(10)] public Transform hips;
[Header("Body Parts")] public Transform hips;
public Transform chest;
public Transform spine;
public Transform head;

public Transform forearmR;
public Transform handR;
[Header("Orientation")] [Space(10)]
public OrientationCubeController orientationCube;
OrientationCubeController m_OrientationCube;
//The indicator graphic gameobject that points towards the target
DirectionIndicator m_DirectionIndicator;
orientationCube.UpdateOrientation(hips, target.transform);
m_OrientationCube = GetComponentInChildren<OrientationCubeController>();
m_DirectionIndicator = GetComponentInChildren<DirectionIndicator>();
//Setup each body part
m_JdController = GetComponent<JointDriveController>();

}
//Random start rotation to help generalize
transform.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);
hips.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);
UpdateOrientationObjects();
orientationCube.UpdateOrientation(hips, target.transform);
//Set our goal walking speed
MTargetWalkingSpeed =
randomizeWalkSpeedEachEpisode ? Random.Range(0.1f, m_maxWalkingSpeed) : MTargetWalkingSpeed;
SetResetParameters();
}

//Get velocities in the context of our orientation cube's space
//Note: You can get these velocities in world space as well but it may not train as well.
sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.velocity));
sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity));
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.velocity));
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity));
sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.position - hips.position));
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.position - hips.position));
if (bp.rb.transform != hips && bp.rb.transform != handL && bp.rb.transform != handR)
{

/// </summary>
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(Quaternion.FromToRotation(hips.forward, orientationCube.transform.forward));
sensor.AddObservation(Quaternion.FromToRotation(head.forward, orientationCube.transform.forward));
var cubeForward = m_OrientationCube.transform.forward;
//velocity we want to match
var velGoal = cubeForward * MTargetWalkingSpeed;
//ragdoll's avg vel
var avgVel = GetAvgVelocity();
//current ragdoll velocity. normalized
sensor.AddObservation(Vector3.Distance(velGoal, avgVel));
//avg body vel relative to cube
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(avgVel));
//vel goal relative to cube
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(velGoal));
//rotation deltas
sensor.AddObservation(Quaternion.FromToRotation(hips.forward, cubeForward));
sensor.AddObservation(Quaternion.FromToRotation(head.forward, cubeForward));
sensor.AddObservation(orientationCube.transform.InverseTransformPoint(target.transform.position));
//Position of target position relative to cube
sensor.AddObservation(m_OrientationCube.transform.InverseTransformPoint(target.transform.position));
foreach (var bodyPart in m_JdController.bodyPartsList)
{

public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
bpDict[chest].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]);
bpDict[spine].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]);
var continuousActions = actionBuffers.ContinuousActions;
bpDict[chest].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
bpDict[spine].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
bpDict[thighL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[thighR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[shinL].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[shinR].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[footR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]);
bpDict[footL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]);
bpDict[thighL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[thighR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[shinL].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[shinR].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[footR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
bpDict[footL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
bpDict[armL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[armR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[forearmL].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[forearmR].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[head].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[armL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[armR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[forearmL].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[forearmR].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[head].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[chest].SetJointStrength(vectorAction[++i]);
bpDict[spine].SetJointStrength(vectorAction[++i]);
bpDict[head].SetJointStrength(vectorAction[++i]);
bpDict[thighL].SetJointStrength(vectorAction[++i]);
bpDict[shinL].SetJointStrength(vectorAction[++i]);
bpDict[footL].SetJointStrength(vectorAction[++i]);
bpDict[thighR].SetJointStrength(vectorAction[++i]);
bpDict[shinR].SetJointStrength(vectorAction[++i]);
bpDict[footR].SetJointStrength(vectorAction[++i]);
bpDict[armL].SetJointStrength(vectorAction[++i]);
bpDict[forearmL].SetJointStrength(vectorAction[++i]);
bpDict[armR].SetJointStrength(vectorAction[++i]);
bpDict[forearmR].SetJointStrength(vectorAction[++i]);
bpDict[chest].SetJointStrength(continuousActions[++i]);
bpDict[spine].SetJointStrength(continuousActions[++i]);
bpDict[head].SetJointStrength(continuousActions[++i]);
bpDict[thighL].SetJointStrength(continuousActions[++i]);
bpDict[shinL].SetJointStrength(continuousActions[++i]);
bpDict[footL].SetJointStrength(continuousActions[++i]);
bpDict[thighR].SetJointStrength(continuousActions[++i]);
bpDict[shinR].SetJointStrength(continuousActions[++i]);
bpDict[footR].SetJointStrength(continuousActions[++i]);
bpDict[armL].SetJointStrength(continuousActions[++i]);
bpDict[forearmL].SetJointStrength(continuousActions[++i]);
bpDict[armR].SetJointStrength(continuousActions[++i]);
bpDict[forearmR].SetJointStrength(continuousActions[++i]);
}
//Update OrientationCube and DirectionIndicator
void UpdateOrientationObjects()
{
m_WorldDirToWalk = target.position - hips.position;
m_OrientationCube.UpdateOrientation(hips, target);
if (m_DirectionIndicator)
{
m_DirectionIndicator.MatchOrientation(m_OrientationCube.transform);
}
var cubeForward = orientationCube.transform.forward;
orientationCube.UpdateOrientation(hips, target.transform);
UpdateOrientationObjects();
var cubeForward = m_OrientationCube.transform.forward;
// a. Velocity alignment with goal direction.
var moveTowardsTargetReward = Vector3.Dot(cubeForward,
Vector3.ClampMagnitude(m_JdController.bodyPartsDict[hips].rb.velocity, maximumWalkingSpeed));
if (float.IsNaN(moveTowardsTargetReward))
// a. Match target speed
//This reward will approach 1 if it matches perfectly and approach zero as it deviates
var matchSpeedReward = GetMatchingVelocityReward(cubeForward * MTargetWalkingSpeed, GetAvgVelocity());
//Check for NaNs
if (float.IsNaN(matchSpeedReward))
$" cubeForward: {cubeForward}\n"+
$" hips.velocity: {m_JdController.bodyPartsDict[hips].rb.velocity}\n"+
$" maximumWalkingSpeed: {maximumWalkingSpeed}"
$" cubeForward: {cubeForward}\n" +
$" hips.velocity: {m_JdController.bodyPartsDict[hips].rb.velocity}\n" +
$" maximumWalkingSpeed: {m_maxWalkingSpeed}"
// b. Rotation alignment with goal direction.
var lookAtTargetReward = Vector3.Dot(cubeForward, head.forward);
// b. Rotation alignment with target direction.
//This reward will approach 1 if it faces the target direction perfectly and approach zero as it deviates
var lookAtTargetReward = (Vector3.Dot(cubeForward, head.forward) + 1) * .5F;
//Check for NaNs
$" cubeForward: {cubeForward}\n"+
$" cubeForward: {cubeForward}\n" +
// c. Encourage head height. //Should normalize to ~1
var headHeightOverFeetReward =
((head.position.y - footL.position.y) + (head.position.y - footR.position.y) / 10);
if (float.IsNaN(headHeightOverFeetReward))
AddReward(matchSpeedReward * lookAtTargetReward);
}
//Returns the average velocity of all of the body parts
//Using the velocity of the hips only has shown to result in more erratic movement from the limbs, so...
//...using the average helps prevent this erratic movement
Vector3 GetAvgVelocity()
{
Vector3 velSum = Vector3.zero;
Vector3 avgVel = Vector3.zero;
//ALL RBS
int numOfRB = 0;
foreach (var item in m_JdController.bodyPartsList)
throw new ArgumentException(
"NaN in headHeightOverFeetReward.\n" +
$" head.position: {head.position}\n"+
$" footL.position: {footL.position}\n"+
$" footR.position: {footR.position}"
);
numOfRB++;
velSum += item.rb.velocity;
AddReward(
+ 0.02f * moveTowardsTargetReward
+ 0.02f * lookAtTargetReward
+ 0.005f * headHeightOverFeetReward
);
avgVel = velSum / numOfRB;
return avgVel;
}
//normalized value of the difference in avg speed vs goal walking speed.
public float GetMatchingVelocityReward(Vector3 velocityGoal, Vector3 actualVelocity)
{
//distance between our actual velocity and goal velocity
var velDeltaMagnitude = Mathf.Clamp(Vector3.Distance(actualVelocity, velocityGoal), 0, MTargetWalkingSpeed);
//return the value on a declining sigmoid shaped curve that decays from 1 to 0
//This reward will approach 1 if it matches perfectly and approach zero as it deviates
return Mathf.Pow(1 - Mathf.Pow(velDeltaMagnitude / MTargetWalkingSpeed, 2), 2);
}
/// <summary>
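
Restating the reward change above (no new behavior): with delta = ||velocityGoal - actualVelocity|| clamped to [0, MTargetWalkingSpeed], GetMatchingVelocityReward returns

    matchSpeedReward = (1 - (delta / MTargetWalkingSpeed)^2)^2

which is 1 for a perfect match, about 0.56 when the speed error is half the target speed, and 0 once the error reaches the target speed. The per-step reward in the hunk is then matchSpeedReward * lookAtTargetReward, with lookAtTargetReward = (Vector3.Dot(cubeForward, head.forward) + 1) * 0.5, replacing the old weighted sum of moveTowardsTargetReward, lookAtTargetReward, and headHeightOverFeetReward.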

1001
Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamic.nn
File diff too large to display.

2
Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamic.nn.meta


fileFormatVersion: 2
guid: e785133c5b0ac461588106642550d1b3
guid: 8cbae6de45ea44d0c97366e252052722
ScriptedImporter:
fileIDToRecycleName:
11400000: main obj

1001
Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStatic.nn
File diff too large to display.

2
Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStatic.nn.meta


fileFormatVersion: 2
guid: 8dfd4337ed40e4d48872a4f86919c9da
guid: 185990f76b7804d1e83378e9d4454c6b
ScriptedImporter:
fileIDToRecycleName:
11400000: main obj

31
Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgentsExamples;

m_GroundRenderer.material = m_GroundMaterial;
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
{
AddReward(-0.0005f);
var smallGrounded = DoGroundCheck(true);

var rotateDir = Vector3.zero;
var dirToGoForwardAction = (int)act[0];
var rotateDirAction = (int)act[1];
var dirToGoSideAction = (int)act[2];
var jumpAction = (int)act[3];
var dirToGoForwardAction = act[0];
var rotateDirAction = act[1];
var dirToGoSideAction = act[2];
var jumpAction = act[3];
if (dirToGoForwardAction == 1)
dirToGo = (largeGrounded ? 1f : 0.5f) * 1f * transform.forward;

jumpingTime -= Time.fixedDeltaTime;
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
if ((!Physics.Raycast(m_AgentRb.position, Vector3.down, 20))
|| (!Physics.Raycast(m_ShortBlockRb.position, Vector3.down, 20)))
{

}
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
System.Array.Clear(actionsOut, 0, actionsOut.Length);
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut.Clear();
actionsOut[1] = 2f;
discreteActionsOut[1] = 2;
actionsOut[0] = 1f;
discreteActionsOut[0] = 1;
actionsOut[1] = 1f;
discreteActionsOut[1] = 1;
actionsOut[0] = 2f;
discreteActionsOut[0] = 2;
actionsOut[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
discreteActionsOut[3] = Input.GetKey(KeyCode.Space) ? 1 : 0;
}
// Detect when the agent hits the goal

17
Project/Assets/ML-Agents/Examples/Worm/Scripts/WormAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;

AddReward(1f);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var continuousActions = actionBuffers.ContinuousActions;
bpDict[bodySegment1].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[bodySegment2].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[bodySegment3].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[bodySegment1].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[bodySegment2].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[bodySegment3].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[bodySegment1].SetJointStrength(vectorAction[++i]);
bpDict[bodySegment2].SetJointStrength(vectorAction[++i]);
bpDict[bodySegment3].SetJointStrength(vectorAction[++i]);
bpDict[bodySegment1].SetJointStrength(continuousActions[++i]);
bpDict[bodySegment2].SetJointStrength(continuousActions[++i]);
bpDict[bodySegment3].SetJointStrength(continuousActions[++i]);
// Detect if worm fell off/through platform
if (bodySegment0.position.y < ground.position.y - 2)

2
Project/ProjectSettings/ProjectVersion.txt


m_EditorVersion: 2018.4.17f1
m_EditorVersion: 2018.4.24f1

4
com.unity.ml-agents.extensions/Editor/Unity.ML-Agents.Extensions.Editor.asmdef


{
"name": "Unity.ML-Agents.Extensions.Editor",
"references": [
"Unity.ML-Agents.Extensions"
"Unity.ML-Agents.Extensions",
"Unity.ML-Agents",
"Unity.ML-Agents.Editor"
],
"includePlatforms": [
"Editor"

1
com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs


using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.Editor")]

29
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs


return new Pose { rotation = t.rotation, position = t.position };
}
/// <inheritdoc/>
protected internal override Object GetObjectAt(int index)
{
return m_Bodies[index];
}
internal IEnumerable<ArticulationBody> GetEnabledArticulationBodies()
{
if (m_Bodies == null)
{
yield break;
}
for (var i = 0; i < m_Bodies.Length; i++)
{
var articBody = m_Bodies[i];
if (articBody == null)
{
// Ignore a virtual root.
continue;
}
if (IsPoseEnabled(i))
{
yield return articBody;
}
}
}
}
}
#endif // UNITY_2020_1_OR_NEWER

8
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs


var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
// Start from i=1 to ignore the root
for (var i = 1; i < poseExtractor.Bodies.Length; i++)
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
numJointObservations += ArticulationBodyJointExtractor.NumObservations(
poseExtractor.Bodies[i], Settings
);
numJointObservations += ArticulationBodyJointExtractor.NumObservations(articBody, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}

54
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

string m_SensorName;
PoseExtractor m_PoseExtractor;
IJointExtractor[] m_JointExtractors;
List<IJointExtractor> m_JointExtractors;
/// Construct a new PhysicsBodySensor
/// Construct a new PhysicsBodySensor
/// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
/// <param name="poseExtractor"></param>
Rigidbody rootBody,
GameObject rootGameObject,
GameObject virtualRoot,
RigidBodyPoseExtractor poseExtractor,
string sensorName=null
string sensorName
var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject, virtualRoot);
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_SensorName = sensorName;
var rigidBodies = poseExtractor.Bodies;
if (rigidBodies != null)
{
m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root
for (var i = 1; i < rigidBodies.Length; i++)
{
var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
foreach(var rb in poseExtractor.GetEnabledRigidbodies())
m_JointExtractors = new IJointExtractor[0];
var jointExtractor = new RigidBodyJointExtractor(rb);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors.Add(jointExtractor);
}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);

m_Settings = settings;
var numJointExtractorObservations = 0;
var articBodies = poseExtractor.Bodies;
if (articBodies != null)
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root
for (var i = 1; i < articBodies.Length; i++)
{
var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
{
m_JointExtractors = new IJointExtractor[0];
var jointExtractor = new ArticulationBodyJointExtractor(articBody);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors.Add(jointExtractor);
}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);

117
com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs


using System;
using Object = UnityEngine.Object;
namespace Unity.MLAgents.Extensions.Sensors
{

{
if (m_ParentIndices == null)
{
return -1;
throw new NullReferenceException("No parent indices set");
}
return m_ParentIndices[index];

public void SetPoseEnabled(int index, bool val)
{
m_PoseEnabled[index] = val;
}
<