using UnityEngine;
using MLAgents;
/// <summary>
/// Two-limb "reacher" arm agent. Each step it applies torques to its two
/// rigidbody limbs while a goal object orbits the agent at a fixed radius.
/// </summary>
public class ReacherAgent : Agent
{
    public GameObject pendulumA;
    public GameObject pendulumB;
    public GameObject hand;
    public GameObject goal;

    // Factor applied to each clamped action component to obtain a torque.
    private const float MaxTorque = 150f;
    // Radius of the circle the goal travels on around the agent's position.
    private const float GoalOrbitRadius = 8f;
    // Fixed vertical offset of the goal relative to the agent's position.
    private const float GoalHeight = -1f;

    private ReacherAcademy myAcademy;
    // Current angle of the goal on its orbit, in degrees, kept in [0, 360).
    private float goalDegree;
    private Rigidbody rbA;
    private Rigidbody rbB;
    // Degrees the goal advances per AgentAction call (may be negative).
    private float goalSpeed;
    private float goalSize;

    /// <summary>
    /// Caches the rigidbodies of both limbs so they can be reused for
    /// observations and actions, and locates the scene's <see cref="ReacherAcademy"/>.
    /// </summary>
    public override void InitializeAgent()
    {
        rbA = pendulumA.GetComponent<Rigidbody>();
        rbB = pendulumB.GetComponent<Rigidbody>();

        // Prefer the object named "Academy", but fall back to a type search so a
        // renamed academy object does not crash here. Unity overloads == for
        // missing/destroyed objects, so compare with != null rather than ?? / is null.
        var academyObject = GameObject.Find("Academy");
        myAcademy = academyObject != null ? academyObject.GetComponent<ReacherAcademy>() : null;
        if (myAcademy == null)
        {
            myAcademy = FindObjectOfType<ReacherAcademy>();
        }
        if (myAcademy == null)
        {
            Debug.LogError("ReacherAgent: no ReacherAcademy found in the scene.");
        }
    }

    /// <summary>
    /// Adds, for each limb: its local position, world rotation, angular velocity
    /// and linear velocity. Then the local positions of the goal and the hand,
    /// and the goal's orbit speed. None of the values are normalized here.
    /// Total is 33 floats (Vector3 = 3, Quaternion = 4, float = 1).
    /// </summary>
    public override void CollectObservations()
    {
        AddVectorObs(pendulumA.transform.localPosition);
        AddVectorObs(pendulumA.transform.rotation);
        AddVectorObs(rbA.angularVelocity);
        AddVectorObs(rbA.velocity);

        AddVectorObs(pendulumB.transform.localPosition);
        AddVectorObs(pendulumB.transform.rotation);
        AddVectorObs(rbB.angularVelocity);
        AddVectorObs(rbB.velocity);

        AddVectorObs(goal.transform.localPosition);
        AddVectorObs(hand.transform.localPosition);
        AddVectorObs(goalSpeed);
    }

    /// <summary>
    /// Advances the orbiting goal, then applies torques to both limbs.
    /// </summary>
    /// <param name="vectorAction">
    /// Four continuous values. [0] and [1] are the X and Z torque on limb A;
    /// [2] and [3] are the X and Z torque on limb B. Each value is clamped to
    /// [-1, 1] and scaled by <see cref="MaxTorque"/>.
    /// </param>
    /// <param name="textAction">Not used by this agent.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Wrap into [0, 360) so precision does not degrade over long episodes.
        // The goal position is unaffected because sin/cos are periodic.
        goalDegree = Mathf.Repeat(goalDegree + goalSpeed, 360f);
        UpdateGoalPosition();

        rbA.AddTorque(ScaledTorque(vectorAction[0], vectorAction[1]));
        rbB.AddTorque(ScaledTorque(vectorAction[2], vectorAction[3]));
    }

    /// <summary>
    /// Builds a torque about the X and Z axes from two raw action values,
    /// each clamped to [-1, 1] and multiplied by <see cref="MaxTorque"/>.
    /// </summary>
    private static Vector3 ScaledTorque(float actionX, float actionZ)
    {
        return new Vector3(
            Mathf.Clamp(actionX, -1f, 1f) * MaxTorque,
            0f,
            Mathf.Clamp(actionZ, -1f, 1f) * MaxTorque);
    }

    /// <summary>
    /// Places the goal at angle <c>goalDegree</c> on a circle of radius
    /// <see cref="GoalOrbitRadius"/> in the X/Z plane around the agent,
    /// offset vertically by <see cref="GoalHeight"/>.
    /// </summary>
    private void UpdateGoalPosition()
    {
        var radians = goalDegree * Mathf.Deg2Rad;
        // Sine drives X and cosine drives Z.
        var offset = new Vector3(
            GoalOrbitRadius * Mathf.Sin(radians),
            GoalHeight,
            GoalOrbitRadius * Mathf.Cos(radians));
        goal.transform.position = offset + transform.position;
    }

    /// <summary>
    /// Returns both limbs to their starting pose at rest. Then picks a new random
    /// goal angle and orbit speed, and applies the academy's goal size.
    /// </summary>
    public override void AgentReset()
    {
        ResetLimb(pendulumA, rbA, new Vector3(0f, -4f, 0f));
        ResetLimb(pendulumB, rbB, new Vector3(0f, -10f, 0f));

        // Integer overload: a whole-degree starting angle in [0, 360).
        goalDegree = Random.Range(0, 360);
        UpdateGoalPosition();

        goalSize = myAcademy.goalSize;
        goalSpeed = Random.Range(-1f, 1f) * myAcademy.goalSpeed;
        goal.transform.localScale = new Vector3(goalSize, goalSize, goalSize);
    }

    /// <summary>
    /// Moves a limb to the given offset from the agent and sets its rotation
    /// to Euler(180, 0, 0). Also zeroes the limb's linear and angular velocity.
    /// </summary>
    private void ResetLimb(GameObject limb, Rigidbody body, Vector3 offsetFromAgent)
    {
        limb.transform.position = offsetFromAgent + transform.position;
        limb.transform.rotation = Quaternion.Euler(180f, 0f, 0f);
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;
    }
}