using UnityEngine;
using MLAgents;

/// <summary>
/// ML-Agents agent controlling a two-limb arm ("reacher"). Each step it applies
/// torques to both limbs while a goal object orbits around the agent's origin.
/// </summary>
public class ReacherAgent : Agent
{
    // First limb; reset to 4 units below the agent origin.
    public GameObject pendulumA;
    // Second limb; reset to 10 units below the agent origin.
    public GameObject pendulumB;
    // Hand object; only its local position is observed.
    public GameObject hand;
    // Goal object moved around the agent by UpdateGoalPosition.
    public GameObject goal;

    // Multiplier applied to each action after clamping it to [-1, 1].
    private const float MaxTorque = 150f;
    // Radius of the goal's orbit around the agent origin in the horizontal plane.
    private const float GoalOrbitRadius = 8f;

    private ReacherAcademy myAcademy;
    // Current angle of the goal on its orbit, in degrees.
    private float goalDegree;
    private Rigidbody rbA;
    private Rigidbody rbB;
    // Angular speed of the goal around the arm, in degrees per AgentAction call
    // (it is accumulated into goalDegree). Sign is randomized on every reset.
    private float goalSpeed;
    // Uniform scale applied to the goal object on reset.
    private float goalSize;
    // Magnitude of the sinusoidal (cosine) deviation of the goal along the vertical axis.
    private float deviation;
    // Frequency of the cosine deviation of the goal along the vertical axis.
    private float deviationFreq;

    /// <summary>
    /// Caches the rigidbodies of both limbs so they can be reused for observations
    /// and actions, locates the academy, and pulls the initial reset parameters.
    /// </summary>
    public override void InitializeAgent()
    {
        rbA = pendulumA.GetComponent<Rigidbody>();
        rbB = pendulumB.GetComponent<Rigidbody>();
        // NOTE(review): GameObject.Find returns null if no object is named "Academy";
        // the scene is assumed to contain one.
        myAcademy = GameObject.Find("Academy").GetComponent<ReacherAcademy>();
        SetResetParameters();
    }

    /// <summary>
    /// Observes, for each limb: local position, world rotation, angular velocity and
    /// velocity. Then observes the goal and hand local positions and the goal's speed.
    /// </summary>
    public override void CollectObservations()
    {
        AddVectorObs(pendulumA.transform.localPosition);
        AddVectorObs(pendulumA.transform.rotation);
        AddVectorObs(rbA.angularVelocity);
        AddVectorObs(rbA.velocity);

        AddVectorObs(pendulumB.transform.localPosition);
        AddVectorObs(pendulumB.transform.rotation);
        AddVectorObs(rbB.angularVelocity);
        AddVectorObs(rbB.velocity);

        AddVectorObs(goal.transform.localPosition);
        AddVectorObs(hand.transform.localPosition);

        AddVectorObs(goalSpeed);
    }

    /// <summary>
    /// Advances the goal along its orbit, then applies torques to the limbs.
    /// The four actions are X/Z torques for the first limb (indices 0-1) and
    /// X/Z torques for the second limb (indices 2-3).
    /// </summary>
    /// <param name="vectorAction">At least four continuous actions; each is clamped to [-1, 1].</param>
    /// <param name="textAction">Unused.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        goalDegree += goalSpeed;
        UpdateGoalPosition();

        rbA.AddTorque(ActionToTorque(vectorAction[0], vectorAction[1]));
        rbB.AddTorque(ActionToTorque(vectorAction[2], vectorAction[3]));
    }

    /// <summary>
    /// Converts two raw actions into a torque about the X and Z axes.
    /// Each action is clamped to [-1, 1] and scaled by <see cref="MaxTorque"/>.
    /// </summary>
    private static Vector3 ActionToTorque(float actionX, float actionZ)
    {
        var torqueX = Mathf.Clamp(actionX, -1f, 1f) * MaxTorque;
        var torqueZ = Mathf.Clamp(actionZ, -1f, 1f) * MaxTorque;
        return new Vector3(torqueX, 0f, torqueZ);
    }

    /// <summary>
    /// Places the goal on a circle of radius <see cref="GoalOrbitRadius"/> in the
    /// horizontal (X/Z) plane around the agent. A cosine offset of amplitude
    /// <c>deviation</c> is added on the vertical (Y) axis.
    /// </summary>
    private void UpdateGoalPosition()
    {
        var radians = goalDegree * Mathf.PI / 180f;
        var orbitZ = GoalOrbitRadius * Mathf.Cos(radians);
        var orbitX = GoalOrbitRadius * Mathf.Sin(radians);
        var offsetY = deviation * Mathf.Cos(deviationFreq * radians);
        goal.transform.position = new Vector3(orbitX, offsetY, orbitZ) + transform.position;
    }

    /// <summary>
    /// Resets the position and velocity of both limbs, refreshes the reset
    /// parameters, and places the goal at a random integer angle on its orbit.
    /// </summary>
    public override void AgentReset()
    {
        ResetLimb(pendulumA, rbA, new Vector3(0f, -4f, 0f));
        ResetLimb(pendulumB, rbB, new Vector3(0f, -10f, 0f));

        goalDegree = Random.Range(0, 360);
        // Refresh parameters before positioning the goal so its initial placement
        // uses this episode's deviation settings instead of the previous episode's.
        SetResetParameters();
        UpdateGoalPosition();

        goal.transform.localScale = new Vector3(goalSize, goalSize, goalSize);
    }

    /// <summary>
    /// Moves a limb to <paramref name="offset"/> relative to the agent origin,
    /// rotates it 180 degrees about X, and zeroes its linear and angular velocity.
    /// </summary>
    private void ResetLimb(GameObject limb, Rigidbody body, Vector3 offset)
    {
        limb.transform.position = offset + transform.position;
        limb.transform.rotation = Quaternion.Euler(180f, 0f, 0f);
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;
    }

    /// <summary>
    /// Reads goal size, speed and vertical deviation from the academy's reset
    /// parameters. The goal speed's magnitude and direction are randomized by
    /// scaling "goal_speed" with a random factor in [-1, 1].
    /// </summary>
    public void SetResetParameters()
    {
        goalSize = myAcademy.resetParameters["goal_size"];
        goalSpeed = Random.Range(-1f, 1f) * myAcademy.resetParameters["goal_speed"];
        deviation = myAcademy.resetParameters["deviation"];
        deviationFreq = myAcademy.resetParameters["deviation_freq"];
    }
}