using UnityEngine;
using MLAgents;

public class ReacherAgent : Agent
{
    public GameObject pendulumA;
    public GameObject pendulumB;
    public GameObject hand;
    public GameObject goal;
    private ReacherAcademy myAcademy;
    // current angle of the goal around the agent (in degrees)
    private float goalDegree;
    private Rigidbody rbA;
    private Rigidbody rbB;
    // speed of the goal zone around the arm (in degrees per AgentAction call)
    private float goalSpeed;
    // size of the goal zone, applied as its uniform local scale
    private float goalSize;
    // Magnitude of sinusoidal (cosine) deviation of the goal along the vertical dimension
    private float deviation;
    // Frequency of the cosine deviation of the goal along the vertical dimension
    private float deviationFreq;

    /// <summary>
    /// Collects the rigidbodies of the reacher in order to reuse them for
    /// observations and actions, locates the ReacherAcademy, and applies its
    /// reset parameters.
    /// </summary>
    public override void InitializeAgent()
    {
        rbA = pendulumA.GetComponent<Rigidbody>();
        rbB = pendulumB.GetComponent<Rigidbody>();
        myAcademy = GameObject.Find("Academy").GetComponent<ReacherAcademy>();

        SetResetParameters();
    }

    /// <summary>
    /// We collect the local position, world rotation, angular velocity, and
    /// velocity of both limbs of the reacher, the local positions of the goal
    /// and the hand, and the goal's speed.
    /// </summary>
    public override void CollectObservations()
    {
        AddVectorObs(pendulumA.transform.localPosition);
        AddVectorObs(pendulumA.transform.rotation);
        AddVectorObs(rbA.angularVelocity);
        AddVectorObs(rbA.velocity);

        AddVectorObs(pendulumB.transform.localPosition);
        AddVectorObs(pendulumB.transform.rotation);
        AddVectorObs(rbB.angularVelocity);
        AddVectorObs(rbB.velocity);

        AddVectorObs(goal.transform.localPosition);
        AddVectorObs(hand.transform.localPosition);

        AddVectorObs(goalSpeed);
    }

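    // Observation size tally for CollectObservations (a sketch, assuming
    // AddVectorObs appends 3 floats per Vector3, 4 per Quaternion, and 1 per
    // float):
    //   per limb: localPosition (3) + rotation (4) + angularVelocity (3)
    //             + velocity (3) = 13, so both limbs give 26
    //   goal localPosition (3) + hand localPosition (3) = 6
    //   goalSpeed (1)
    //   total = 26 + 6 + 1 = 33
    // Under that assumption, the vector observation space size configured for
    // this agent would need to be 33.
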
    /// <summary>
    /// The agent's four continuous actions are torques on the two limbs:
    /// actions 0 and 1 are the x and z torques on pendulumA, and actions 2 and 3
    /// are the x and z torques on pendulumB. Each action is clamped to [-1, 1]
    /// and scaled by 150. Each call also advances the goal by goalSpeed degrees.
    /// </summary>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        goalDegree += goalSpeed;
        UpdateGoalPosition();

        var torqueX = Mathf.Clamp(vectorAction[0], -1f, 1f) * 150f;
        var torqueZ = Mathf.Clamp(vectorAction[1], -1f, 1f) * 150f;
        rbA.AddTorque(new Vector3(torqueX, 0f, torqueZ));

        torqueX = Mathf.Clamp(vectorAction[2], -1f, 1f) * 150f;
        torqueZ = Mathf.Clamp(vectorAction[3], -1f, 1f) * 150f;
        rbB.AddTorque(new Vector3(torqueX, 0f, torqueZ));
    }

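    // Worked example for AgentAction (hypothetical action values, for
    // illustration only):
    //   vectorAction = { 0.5f, -2f, 0f, 1f }
    //   rbA torque: x = 0.5 * 150 = 75, z = Clamp(-2, -1, 1) * 150 = -150
    //               -> AddTorque(new Vector3(75f, 0f, -150f))
    //   rbB torque: x = 0 * 150 = 0, z = 1 * 150 = 150
    //               -> AddTorque(new Vector3(0f, 0f, 150f))
    // Out-of-range outputs such as -2 saturate at the 150 torque limit rather
    // than producing larger torques.
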
    /// <summary>
    /// Moves the goal around the agent on a circle of radius 8 in the world XZ
    /// plane, displaced vertically by a cosine of the goal angle with amplitude
    /// deviation and frequency deviationFreq.
    /// </summary>
    void UpdateGoalPosition()
    {
        var radians = goalDegree * Mathf.PI / 180f;
        var goalX = 8f * Mathf.Cos(radians);
        var goalY = 8f * Mathf.Sin(radians);
        var goalZ = deviation * Mathf.Cos(deviationFreq * radians);
        goal.transform.position = new Vector3(goalY, goalZ, goalX) + transform.position;
    }

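    // Worked example for UpdateGoalPosition (hypothetical goal angles, for
    // illustration only):
    //   goalDegree = 90  -> radians = 90 * PI / 180 = PI / 2
    //   goalX = 8 * cos(PI / 2) = 0 (up to float rounding)
    //   goalY = 8 * sin(PI / 2) = 8
    //   goalZ = deviation * cos(deviationFreq * PI / 2)
    //   goal.transform.position = transform.position + (8, goalZ, 0)
    // With goalDegree = 0 the offset is (0, deviation, 8). The offset vector is
    // assembled as (goalY, goalZ, goalX), so goalY lands on world x, goalZ on
    // world y (vertical), and goalX on world z.
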
    /// <summary>
    /// Resets both limbs of the agent to their initial pose with zero velocity,
    /// places the goal at a random angle, re-reads the reset parameters, and
    /// rescales the goal.
    /// </summary>
    public override void AgentReset()
    {
        pendulumA.transform.position = new Vector3(0f, -4f, 0f) + transform.position;
        pendulumA.transform.rotation = Quaternion.Euler(180f, 0f, 0f);
        rbA.velocity = Vector3.zero;
        rbA.angularVelocity = Vector3.zero;

        pendulumB.transform.position = new Vector3(0f, -10f, 0f) + transform.position;
        pendulumB.transform.rotation = Quaternion.Euler(180f, 0f, 0f);
        rbB.velocity = Vector3.zero;
        rbB.angularVelocity = Vector3.zero;

        goalDegree = Random.Range(0, 360);
        UpdateGoalPosition();

        SetResetParameters();

        goal.transform.localScale = new Vector3(goalSize, goalSize, goalSize);
    }

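    /// <summary>
    /// Reads the goal's size, speed, and vertical deviation from the Academy's
    /// reset parameters. The configured goal speed is multiplied by a random
    /// factor in [-1, 1], so each episode's goal orbits at a different rate and
    /// in either direction.
    /// </summary>
    public void SetResetParameters()
    {
        goalSize = myAcademy.resetParameters["goal_size"];
        goalSpeed = Random.Range(-1f, 1f) * myAcademy.resetParameters["goal_speed"];
        deviation = myAcademy.resetParameters["deviation"];
        deviationFreq = myAcademy.resetParameters["deviation_freq"];
    }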
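
    // Worked example for SetResetParameters (assuming, hypothetically, that the
    // "goal_speed" reset parameter is 1): goalSpeed is drawn from [-1, 1]. A
    // draw of 0.5 advances goalDegree by 0.5 degrees per AgentAction call, so
    // the goal sweeps 180 degrees after 360 calls; a draw of -0.5 moves it the
    // same distance in the opposite direction.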
}