浏览代码

Update example environments to use the Actuator API (#4363)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
bb9417f7
共有 24 个文件被更改,包括 281 次插入212 次删除
  1. 9
      DevProject/Assets/ML-Agents/Scripts/Tests/Performance/SensorPerformanceTests.cs
  2. 21
      Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
  3. 9
      Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs
  4. 9
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicController.cs
  5. 30
      Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs
  6. 37
      Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
  7. 29
      Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
  8. 29
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  9. 23
      Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs
  10. 23
      Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs
  11. 23
      Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs
  12. 12
      Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
  13. 31
      Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
  14. 4
      Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs
  15. 20
      Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
  16. 57
      Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
  17. 31
      Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
  18. 17
      Project/Assets/ML-Agents/Examples/Worm/Scripts/WormAgent.cs
  19. 5
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
  20. 17
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
  21. 2
      com.unity.ml-agents/Runtime/Agent.cs
  22. 2
      com.unity.ml-agents/Runtime/Agent.deprecated.cs
  23. 46
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  24. 7
      com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs

9
DevProject/Assets/ML-Agents/Scripts/Tests/Performance/SensorPerformanceTests.cs


using NUnit.Framework;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;

{
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
{
}
}

sensor.AddObservation(new Quaternion(1, 2, 3, 4));
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
{
}
}

[Observable]
public Quaternion QuaternionField = new Quaternion(1, 2, 3, 4);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
{
}
}

get { return m_QuaternionField; }
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
{
}
}

21
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs


using System;
using Unity.MLAgents.Actuators;
using Random = UnityEngine.Random;
public class Ball3DAgent : Agent
{

public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
sensor.AddObservation(ball.transform.position - gameObject.transform.position);
sensor.AddObservation(m_BallRb.velocity);
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
var actionZ = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);
if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
(gameObject.transform.rotation.z > -0.25f && actionZ < 0f))

SetResetParameters();
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = -Input.GetAxis("Horizontal");
actionsOut[1] = Input.GetAxis("Vertical");
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = -Input.GetAxis("Horizontal");
continuousActionsOut[1] = Input.GetAxis("Vertical");
}
public void SetBall()

9
Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class Ball3DHardAgent : Agent

sensor.AddObservation((ball.transform.position - gameObject.transform.position));
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var actionZ = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
var continuousActions = actionBuffers.ContinuousActions;
var actionZ = 2f * Mathf.Clamp(continuousActions[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(continuousActions[1], -1f, 1f);
if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
(gameObject.transform.rotation.z > -0.25f && actionZ < 0f))

9
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicController.cs


using UnityEngine;
using UnityEngine.SceneManagement;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using UnityEngine.Serialization;
/// <summary>

/// Controls the movement of the GameObject based on the actions received.
/// </summary>
/// <param name="vectorAction"></param>
public void ApplyAction(float[] vectorAction)
public void ApplyAction(ActionSegment<int> vectorAction)
var movement = (int)vectorAction[0];
var movement = vectorAction[0];
var direction = 0;

if (Academy.Instance.IsCommunicatorOn)
{
// Apply the previous step's actions
ApplyAction(m_Agent.GetAction());
ApplyAction(m_Agent.GetStoredActionBuffers().DiscreteActions);
m_Agent?.RequestDecision();
}
else

// Apply the previous step's actions
ApplyAction(m_Agent.GetAction());
ApplyAction(m_Agent.GetStoredActionBuffers().DiscreteActions);
m_TimeSinceDecision = 0f;
m_Agent?.RequestDecision();

30
Project/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class BouncerAgent : Agent

sensor.AddObservation(target.transform.localPosition);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
for (var i = 0; i < vectorAction.Length; i++)
var continuousActions = actionBuffers.ContinuousActions;
for (var i = 0; i < continuousActions.Length; i++)
vectorAction[i] = Mathf.Clamp(vectorAction[i], -1f, 1f);
continuousActions[i] = Mathf.Clamp(continuousActions[i], -1f, 1f);
var x = vectorAction[0];
var y = ScaleAction(vectorAction[1], 0, 1);
var z = vectorAction[2];
var x = continuousActions[0];
var y = ScaleAction(continuousActions[1], 0, 1);
var z = continuousActions[2];
vectorAction[0] * vectorAction[0] +
vectorAction[1] * vectorAction[1] +
vectorAction[2] * vectorAction[2]) / 3f);
continuousActions[0] * continuousActions[0] +
continuousActions[1] * continuousActions[1] +
continuousActions[2] * continuousActions[2]) / 3f);
m_LookDir = new Vector3(x, y, z);
}

}
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = Input.GetAxis("Horizontal");
actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
actionsOut[2] = Input.GetAxis("Vertical");
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = Input.GetAxis("Horizontal");
continuousActionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
continuousActionsOut[2] = Input.GetAxis("Vertical");
}
void Update()

37
Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs


using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

AddReward(1f);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var continuousActions = actionBuffers.ContinuousActions;
bpDict[leg0Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg1Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg2Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg3Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg0Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg1Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg2Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg3Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg0Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[leg1Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[leg2Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[leg3Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[leg0Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[leg1Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[leg2Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[leg3Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[leg0Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg1Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg2Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg3Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg0Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg1Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg2Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg3Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg0Upper].SetJointStrength(continuousActions[++i]);
bpDict[leg1Upper].SetJointStrength(continuousActions[++i]);
bpDict[leg2Upper].SetJointStrength(continuousActions[++i]);
bpDict[leg3Upper].SetJointStrength(continuousActions[++i]);
bpDict[leg0Lower].SetJointStrength(continuousActions[++i]);
bpDict[leg1Lower].SetJointStrength(continuousActions[++i]);
bpDict[leg2Lower].SetJointStrength(continuousActions[++i]);
bpDict[leg3Lower].SetJointStrength(continuousActions[++i]);
}
void FixedUpdate()

29
Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs


using System;
using Unity.MLAgents.Actuators;
using Random = UnityEngine.Random;
public class FoodCollectorAgent : Agent
{

return new Color32(r, g, b, 255);
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
{
m_Shoot = false;

gameObject.GetComponentInChildren<Renderer>().material = normalMaterial;
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = 0f;
actionsOut[1] = 0f;
actionsOut[2] = 0f;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
discreteActionsOut[1] = 0;
discreteActionsOut[2] = 0;
actionsOut[2] = 2f;
discreteActionsOut[2] = 2;
actionsOut[0] = 1f;
discreteActionsOut[0] = 1;
actionsOut[2] = 1f;
discreteActionsOut[2] = 1;
actionsOut[0] = 2f;
discreteActionsOut[0] = 2;
actionsOut[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
discreteActionsOut[3] = Input.GetKey(KeyCode.Space) ? 1 : 0;
}
public override void OnEpisodeBegin()

29
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using UnityEngine.Serialization;
public class GridAgent : Agent

m_ResetParams = Academy.Instance.EnvironmentParameters;
}
public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
// Mask the necessary actions if selected by the user.
if (maskActions)

if (positionX == 0)
{
actionMasker.SetMask(0, new []{ k_Left});
actionMask.WriteMask(0, new []{ k_Left});
actionMasker.SetMask(0, new []{k_Right});
actionMask.WriteMask(0, new []{k_Right});
actionMasker.SetMask(0, new []{k_Down});
actionMask.WriteMask(0, new []{k_Down});
actionMasker.SetMask(0, new []{k_Up});
actionMask.WriteMask(0, new []{k_Up});
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var action = Mathf.FloorToInt(vectorAction[0]);
var action = actionBuffers.DiscreteActions[0];
var targetPos = transform.position;
switch (action)

}
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = k_NoAction;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = k_NoAction;
actionsOut[0] = k_Right;
discreteActionsOut[0] = k_Right;
actionsOut[0] = k_Up;
discreteActionsOut[0] = k_Up;
actionsOut[0] = k_Left;
discreteActionsOut[0] = k_Left;
actionsOut[0] = k_Down;
discreteActionsOut[0] = k_Down;
}
}

23
Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs


using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class HallwayAgent : Agent

m_GroundRenderer.material = m_GroundMaterial;
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
var action = Mathf.FloorToInt(act[0]);
var action = act[0];
switch (action)
{
case 1:

m_AgentRb.AddForce(dirToGo * m_HallwaySettings.agentRunSpeed, ForceMode.VelocityChange);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
}
void OnCollisionEnter(Collision col)

}
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = 0;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
actionsOut[0] = 3;
discreteActionsOut[0] = 3;
actionsOut[0] = 1;
discreteActionsOut[0] = 1;
actionsOut[0] = 4;
discreteActionsOut[0] = 4;
actionsOut[0] = 2;
discreteActionsOut[0] = 2;
}
}

23
Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs


using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
public class PushAgentBasic : Agent
{

/// <summary>
/// Moves the agent according to the selected action.
/// </summary>
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
var action = Mathf.FloorToInt(act[0]);
var action = act[0];
switch (action)
{

/// <summary>
/// Called every step of the engine. Here the agent takes an action.
/// </summary>
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = 0;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
actionsOut[0] = 3;
discreteActionsOut[0] = 3;
actionsOut[0] = 1;
discreteActionsOut[0] = 1;
actionsOut[0] = 4;
discreteActionsOut[0] = 4;
actionsOut[0] = 2;
discreteActionsOut[0] = 2;
}
}

23
Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs


using UnityEngine;
using Random = UnityEngine.Random;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class PyramidAgent : Agent

}
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
var action = Mathf.FloorToInt(act[0]);
var action = act[0];
switch (action)
{
case 1:

m_AgentRb.AddForce(dirToGo * 2f, ForceMode.VelocityChange);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = 0;
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut[0] = 0;
actionsOut[0] = 3;
discreteActionsOut[0] = 3;
actionsOut[0] = 1;
discreteActionsOut[0] = 1;
actionsOut[0] = 4;
discreteActionsOut[0] = 4;
actionsOut[0] = 2;
discreteActionsOut[0] = 2;
}
}

12
Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class ReacherAgent : Agent

/// <summary>
/// The agent's four actions correspond to torques on each of the two joints.
/// </summary>
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var torqueX = Mathf.Clamp(vectorAction[0], -1f, 1f) * 150f;
var torqueZ = Mathf.Clamp(vectorAction[1], -1f, 1f) * 150f;
var torqueX = Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f) * 150f;
var torqueZ = Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f) * 150f;
torqueX = Mathf.Clamp(vectorAction[2], -1f, 1f) * 150f;
torqueZ = Mathf.Clamp(vectorAction[3], -1f, 1f) * 150f;
torqueX = Mathf.Clamp(actionBuffers.ContinuousActions[2], -1f, 1f) * 150f;
torqueZ = Mathf.Clamp(actionBuffers.ContinuousActions[3], -1f, 1f) * 150f;
m_RbB.AddTorque(new Vector3(torqueX, 0f, torqueZ));
}

31
Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs


using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
public class AgentSoccer : Agent

m_ResetParams = Academy.Instance.EnvironmentParameters;
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
{
var dirToGo = Vector3.zero;
var rotateDir = Vector3.zero;

var forwardAxis = (int)act[0];
var rightAxis = (int)act[1];
var rotateAxis = (int)act[2];
var forwardAxis = act[0];
var rightAxis = act[1];
var rotateAxis = act[2];
switch (forwardAxis)
{

ForceMode.VelocityChange);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
{
if (position == Position.Goalie)

// Existential penalty cumulant for Generic
timePenalty -= m_Existential;
}
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
Array.Clear(actionsOut, 0, actionsOut.Length);
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut.Clear();
actionsOut[0] = 1f;
discreteActionsOut[0] = 1;
actionsOut[0] = 2f;
discreteActionsOut[0] = 2;
actionsOut[2] = 1f;
discreteActionsOut[2] = 1;
actionsOut[2] = 2f;
discreteActionsOut[2] = 2;
actionsOut[1] = 1f;
discreteActionsOut[1] = 1;
actionsOut[1] = 2f;
discreteActionsOut[1] = 2;
}
}
/// <summary>

4
Project/Assets/ML-Agents/Examples/Template/Scripts/TemplateAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class TemplateAgent : Agent

}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
{
}

20
Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs


using UnityEngine;
using UnityEngine.UI;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
public class TennisAgent : Agent

sensor.AddObservation(m_InvertMult * gameObject.transform.rotation.z);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var moveX = Mathf.Clamp(vectorAction[0], -1f, 1f) * m_InvertMult;
var moveY = Mathf.Clamp(vectorAction[1], -1f, 1f);
var rotate = Mathf.Clamp(vectorAction[2], -1f, 1f) * m_InvertMult;
var continuousActions = actionBuffers.ContinuousActions;
var moveX = Mathf.Clamp(continuousActions[0], -1f, 1f) * m_InvertMult;
var moveY = Mathf.Clamp(continuousActions[1], -1f, 1f);
var rotate = Mathf.Clamp(continuousActions[2], -1f, 1f) * m_InvertMult;
if (moveY > 0.5 && transform.position.y - transform.parent.transform.position.y < -1.5f)
{

m_TextComponent.text = score.ToString();
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = Input.GetAxis("Horizontal"); // Racket Movement
actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
actionsOut[2] = Input.GetAxis("Vertical"); // Racket Rotation
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = Input.GetAxis("Horizontal"); // Racket Movement
continuousActionsOut[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
continuousActionsOut[2] = Input.GetAxis("Vertical"); // Racket Rotation
}
public override void OnEpisodeBegin()

57
Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs


using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;
using BodyPart = Unity.MLAgentsExamples.BodyPart;

}
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
bpDict[chest].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]);
bpDict[spine].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]);
var continuousActions = actionBuffers.ContinuousActions;
bpDict[chest].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
bpDict[spine].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
bpDict[thighL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[thighR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[shinL].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[shinR].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[footR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]);
bpDict[footL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]);
bpDict[thighL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[thighR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[shinL].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[shinR].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[footR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
bpDict[footL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], continuousActions[++i]);
bpDict[armL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[armR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[forearmL].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[forearmR].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[head].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[armL].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[armR].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[forearmL].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[forearmR].SetJointTargetRotation(continuousActions[++i], 0, 0);
bpDict[head].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[chest].SetJointStrength(vectorAction[++i]);
bpDict[spine].SetJointStrength(vectorAction[++i]);
bpDict[head].SetJointStrength(vectorAction[++i]);
bpDict[thighL].SetJointStrength(vectorAction[++i]);
bpDict[shinL].SetJointStrength(vectorAction[++i]);
bpDict[footL].SetJointStrength(vectorAction[++i]);
bpDict[thighR].SetJointStrength(vectorAction[++i]);
bpDict[shinR].SetJointStrength(vectorAction[++i]);
bpDict[footR].SetJointStrength(vectorAction[++i]);
bpDict[armL].SetJointStrength(vectorAction[++i]);
bpDict[forearmL].SetJointStrength(vectorAction[++i]);
bpDict[armR].SetJointStrength(vectorAction[++i]);
bpDict[forearmR].SetJointStrength(vectorAction[++i]);
bpDict[chest].SetJointStrength(continuousActions[++i]);
bpDict[spine].SetJointStrength(continuousActions[++i]);
bpDict[head].SetJointStrength(continuousActions[++i]);
bpDict[thighL].SetJointStrength(continuousActions[++i]);
bpDict[shinL].SetJointStrength(continuousActions[++i]);
bpDict[footL].SetJointStrength(continuousActions[++i]);
bpDict[thighR].SetJointStrength(continuousActions[++i]);
bpDict[shinR].SetJointStrength(continuousActions[++i]);
bpDict[footR].SetJointStrength(continuousActions[++i]);
bpDict[armL].SetJointStrength(continuousActions[++i]);
bpDict[forearmL].SetJointStrength(continuousActions[++i]);
bpDict[armR].SetJointStrength(continuousActions[++i]);
bpDict[forearmR].SetJointStrength(continuousActions[++i]);
}
//Update OrientationCube and DirectionIndicator

31
Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgentsExamples;

m_GroundRenderer.material = m_GroundMaterial;
}
public void MoveAgent(float[] act)
public void MoveAgent(ActionSegment<int> act)
{
AddReward(-0.0005f);
var smallGrounded = DoGroundCheck(true);

var rotateDir = Vector3.zero;
var dirToGoForwardAction = (int)act[0];
var rotateDirAction = (int)act[1];
var dirToGoSideAction = (int)act[2];
var jumpAction = (int)act[3];
var dirToGoForwardAction = act[0];
var rotateDirAction = act[1];
var dirToGoSideAction = act[2];
var jumpAction = act[3];
if (dirToGoForwardAction == 1)
dirToGo = (largeGrounded ? 1f : 0.5f) * 1f * transform.forward;

jumpingTime -= Time.fixedDeltaTime;
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
MoveAgent(vectorAction);
MoveAgent(actionBuffers.DiscreteActions);
if ((!Physics.Raycast(m_AgentRb.position, Vector3.down, 20))
|| (!Physics.Raycast(m_ShortBlockRb.position, Vector3.down, 20)))
{

}
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
System.Array.Clear(actionsOut, 0, actionsOut.Length);
var discreteActionsOut = actionsOut.DiscreteActions;
discreteActionsOut.Clear();
actionsOut[1] = 2f;
discreteActionsOut[1] = 2;
actionsOut[0] = 1f;
discreteActionsOut[0] = 1;
actionsOut[1] = 1f;
discreteActionsOut[1] = 1;
actionsOut[0] = 2f;
discreteActionsOut[0] = 2;
actionsOut[3] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
discreteActionsOut[3] = Input.GetKey(KeyCode.Space) ? 1 : 0;
}
// Detect when the agent hits the goal

17
Project/Assets/ML-Agents/Examples/Worm/Scripts/WormAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;

target.position = newTargetPos + ground.position;
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers actionBuffers)
var continuousActions = actionBuffers.ContinuousActions;
bpDict[bodySegment1].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[bodySegment2].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[bodySegment3].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[bodySegment1].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[bodySegment2].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[bodySegment3].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
bpDict[bodySegment1].SetJointStrength(vectorAction[++i]);
bpDict[bodySegment2].SetJointStrength(vectorAction[++i]);
bpDict[bodySegment3].SetJointStrength(vectorAction[++i]);
bpDict[bodySegment1].SetJointStrength(continuousActions[++i]);
bpDict[bodySegment2].SetJointStrength(continuousActions[++i]);
bpDict[bodySegment3].SetJointStrength(continuousActions[++i]);
if (bodySegment0.position.y < ground.position.y -2)
{

5
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs


using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
namespace Unity.MLAgents.Actuators
{

/// </summary>
/// <param name="actionArray">The action array to use for the this segment.</param>
public ActionSegment(T[] actionArray)
: this(actionArray ?? System.Array.Empty<T>(), 0, actionArray?.Length ?? 0) { }
: this(actionArray ?? System.Array.Empty<T>(), 0, actionArray?.Length ?? 0) {}
/// <summary>
/// Construct an <see cref="ActionSegment{T}"/> with an underlying array

/// <inheritdoc cref="IEquatable{T}.Equals(T)"/>
public bool Equals(ActionSegment<T> other)
{
return Offset == other.Offset && Length == other.Length && Equals(Array, other.Array);
return Offset == other.Offset && Length == other.Length && Array.SequenceEqual(other.Array);
}
/// <inheritdoc cref="ValueType.GetHashCode"/>

17
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs


public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;
/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// The currently stored <see cref="ActionBuffers"/> object for the <see cref="IActuator"/>s managed by this class.
// public float[] StoredContinuousActions { get; private set; }
/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// </summary>
// public int[] StoredDiscreteActions { get; private set; }
public ActionBuffers StoredActions { get; private set; }
/// <summary>

/// Updates the local action buffer with the action buffer passed in. If the buffer
/// passed in is null, the local action buffer will be cleared.
/// </summary>
/// <param name="continuousActionBuffer">The action buffer which contains all of the
/// continuous actions for the IActuators in this list.</param>
/// <param name="discreteActionBuffer">The action buffer which contains all of the
/// discrete actions for the IActuators in this list.</param>
/// <param name="actions">The <see cref="ActionBuffers"/> object which contains all of the
/// actions for the IActuators in this list.</param>
public void UpdateActions(ActionBuffers actions)
{
ReadyActuatorsForExecution();

/// <summary>
/// Sorts the <see cref="IActuator"/>s according to their <see cref="IActuator.GetName"/> value.
/// Sorts the <see cref="IActuator"/>s according to their <see cref="IActuator.Name"/> value.
/// </summary>
void SortActuators()
{

2
com.unity.ml-agents/Runtime/Agent.cs


/// <summary>
/// Gets the last ActionBuffer for this agent.
/// </summary>
public ActionBuffers GetStoredContinuousActions()
public ActionBuffers GetStoredActionBuffers()
{
return m_ActuatorManager.StoredActions;
}

2
com.unity.ml-agents/Runtime/Agent.deprecated.cs


/// The last action that was decided by the Agent (or null if no decision has been made).
/// </returns>
/// <seealso cref="OnActionReceived(float[])"/>
// [Obsolete("GetAction has been deprecated, please use GetStoredContinuousActions, Or GetStoredDiscreteActions.")]
// [Obsolete("GetAction has been deprecated, please use GetStoredActionBuffers, Or GetStoredDiscreteActions.")]
public float[] GetAction()
{
return m_Info.storedVectorActions;

46
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


{
public Action OnRequestDecision;
ObservationWriter m_ObsWriter = new ObservationWriter();
static ActionBuffers s_EmptyActionBuffers = new ActionBuffers(Array.Empty<float>(), Array.Empty<int>());
static ActionSpec s_ActionSpec = ActionSpec.MakeContinuous(1);
static ActionBuffers s_EmptyActionBuffers = new ActionBuffers(new float[1], Array.Empty<int>());
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
{
foreach (var sensor in sensors)

sensor.AddObservation(collectObservationsCallsForEpisode);
}
public override void OnActionReceived(float[] vectorAction)
public override void OnActionReceived(ActionBuffers buffers)
{
agentActionCalls += 1;
agentActionCallsForEpisode += 1;

agentActionCallsForEpisode = 0;
}
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
actionsOut[0] = obs[0];
var continuousActions = actionsOut.ContinuousActions;
continuousActions[0] = (int)obs[0];
heuristicCalls++;
}
}

public void TestAgent()
{
var agentGo1 = new GameObject("TestAgent");
var bp1 = agentGo1.AddComponent<BehaviorParameters>();
bp1.BrainParameters.VectorActionSize = new[] { 1 };
bp1.BrainParameters.VectorActionSpaceType = SpaceType.Continuous;
var bp2 = agentGo2.AddComponent<BehaviorParameters>();
bp2.BrainParameters.VectorActionSize = new[] { 1 };
bp2.BrainParameters.VectorActionSpaceType = SpaceType.Continuous;
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();

public void TestAgent()
{
var agentGo1 = new GameObject("TestAgent");
var bp1 = agentGo1.AddComponent<BehaviorParameters>();
bp1.BrainParameters.VectorActionSize = new[] { 1 };
bp1.BrainParameters.VectorActionSpaceType = SpaceType.Continuous;
var bp2 = agentGo2.AddComponent<BehaviorParameters>();
bp2.BrainParameters.VectorActionSize = new[] { 1 };
bp2.BrainParameters.VectorActionSpaceType = SpaceType.Continuous;
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();

public void AssertStackingReset()
{
var agentGo1 = new GameObject("TestAgent");
agentGo1.AddComponent<TestAgent>();
var bp1 = agentGo1.AddComponent<BehaviorParameters>();
bp1.BrainParameters.VectorActionSize = new[] { 1 };
bp1.BrainParameters.VectorActionSpaceType = SpaceType.Continuous;
var agent1 = agentGo1.AddComponent<TestAgent>();
var agent1 = agentGo1.GetComponent<TestAgent>();
var aca = Academy.Instance;
agent1.LazyInitialize();
var policy = new TestPolicy();

public void TestCumulativeReward()
{
var agentGo1 = new GameObject("TestAgent");
agentGo1.AddComponent<TestAgent>();
var agent1 = agentGo1.GetComponent<TestAgent>();
var bp1 = agentGo1.AddComponent<BehaviorParameters>();
bp1.BrainParameters.VectorActionSize = new[] { 1 };
bp1.BrainParameters.VectorActionSpaceType = SpaceType.Continuous;
var agent1 = agentGo1.AddComponent<TestAgent>();
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();
var bp2 = agentGo2.AddComponent<BehaviorParameters>();
bp2.BrainParameters.VectorActionSize = new[] { 1 };
bp2.BrainParameters.VectorActionSpaceType = SpaceType.Continuous;
var agent2 = agentGo2.AddComponent<TestAgent>();
var aca = Academy.Instance;
var decisionRequester = agent1.gameObject.AddComponent<DecisionRequester>();

public void TestMaxStepsReset()
{
var agentGo1 = new GameObject("TestAgent");
var bp1 = agentGo1.AddComponent<BehaviorParameters>();
bp1.BrainParameters.VectorActionSize = new[] { 1 };
bp1.BrainParameters.VectorActionSpaceType = SpaceType.Continuous;
agentGo1.AddComponent<TestAgent>();
var agent1 = agentGo1.GetComponent<TestAgent>();
var aca = Academy.Instance;

{
// Make sure that Agents with HeuristicPolicies step their sensors each Academy step.
var agentGo1 = new GameObject("TestAgent");
var bp1 = agentGo1.AddComponent<BehaviorParameters>();
bp1.BrainParameters.VectorActionSize = new[] { 1 };
bp1.BrainParameters.VectorActionSpaceType = SpaceType.Continuous;
agentGo1.AddComponent<TestAgent>();
var agent1 = agentGo1.GetComponent<TestAgent>();
var aca = Academy.Instance;

7
com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs


using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using UnityEngine;
using UnityEngine.TestTools;

[Observable]
public float ObservableFloat;
public override void Heuristic(float[] actionsOut)
public override void Heuristic(in ActionBuffers actionsOut)
{
numHeuristicCalls++;
base.Heuristic(actionsOut);

Academy.Instance.EnvironmentStep();
var actions = agent.GetAction();
var actions = agent.GetStoredActionBuffers().DiscreteActions;
Assert.AreEqual(new[] {0.0f, 0.0f}, actions);
Assert.AreEqual(new ActionSegment<int>(new[] {0, 0}), actions);
Assert.AreEqual(1, agent.numHeuristicCalls);
Academy.Instance.EnvironmentStep();

正在加载...
取消
保存