using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;

/// <summary>
/// ML-Agents agent that drives a four-legged body (one body plus an upper and
/// lower segment per leg) through a <see cref="JointDriveController"/> and is
/// rewarded for moving towards / facing a target transform.
/// </summary>
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class CrawlerAgent : Agent
{
    //The max walk velocity magnitude an agent will be rewarded for
    public float maximumWalkingSpeed = 999;

    Vector3 m_WalkDir; //Direction to the target
    Quaternion m_WalkDirLookRot; //Will hold the rotation to our target

    [Header("Target To Walk Towards")]
    [Space(10)]
    public Transform target; //Target the agent will walk towards.
    public float targetSpawnRadius; //The radius in which a target can be randomly spawned.
    public bool detectTargets; //Should this agent detect targets
    public bool respawnTargetWhenTouched; //Should the target respawn to a different position when touched
    public Transform ground; //Ground gameobject. The height will be used for target spawning

    [Header("Body Parts")]
    [Space(10)]
    public Transform body;
    public Transform leg0Upper;
    public Transform leg0Lower;
    public Transform leg1Upper;
    public Transform leg1Lower;
    public Transform leg2Upper;
    public Transform leg2Lower;
    public Transform leg3Upper;
    public Transform leg3Lower;

    [Header("Orientation")]
    [Space(10)]
    //This will be used as a stable reference point for observations
    //Because ragdolls can move erratically, using a standalone reference point can significantly improve learning
    public GameObject orientationCube;

    JointDriveController m_JdController;

    [Header("Reward Functions To Use")]
    [Space(10)]
    public bool rewardMovingTowardsTarget; // Agent should move towards target
    public bool rewardFacingTarget; // Agent should face the target
    public bool rewardUseTimePenalty; // Hurry up

    [Header("Foot Grounded Visualization")]
    [Space(10)]
    public bool useFootGroundedVisualization;
    public MeshRenderer foot0;
    public MeshRenderer foot1;
    public MeshRenderer foot2;
    public MeshRenderer foot3;
    public Material groundedMaterial;
    public Material unGroundedMaterial;

    EnvironmentParameters m_ResetParams;

    /// <summary>
    /// Caches the environment parameters and the joint drive controller,
    /// applies the reset parameters and registers every body part with the
    /// controller.
    /// </summary>
    public override void Initialize()
    {
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        UpdateOrientationCube();

        // The component is guaranteed to exist by the RequireComponent attribute on this class.
        m_JdController = GetComponent<JointDriveController>();
        SetResetParameters();

        //Setup each body part
        m_JdController.SetupBodyPart(body);
        m_JdController.SetupBodyPart(leg0Upper);
        m_JdController.SetupBodyPart(leg0Lower);
        m_JdController.SetupBodyPart(leg1Upper);
        m_JdController.SetupBodyPart(leg1Lower);
        m_JdController.SetupBodyPart(leg2Upper);
        m_JdController.SetupBodyPart(leg2Lower);
        m_JdController.SetupBodyPart(leg3Upper);
        m_JdController.SetupBodyPart(leg3Lower);
    }

    /// <summary>
    /// Add relevant information on each body part to observations.
    /// </summary>
    /// <param name="bp">Body part whose state is observed.</param>
    /// <param name="sensor">Sensor the observations are appended to.</param>
    public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
    {
        // Whether the bp is touching the ground
        sensor.AddObservation(bp.groundContact.touchingGround ? 1 : 0);

        //Get velocities in the context of our orientation cube's space
        //Note: You can get these velocities in world space as well but it may not train as well.
        var cubeTransform = orientationCube.transform;
        sensor.AddObservation(cubeTransform.InverseTransformDirection(bp.rb.velocity));
        sensor.AddObservation(cubeTransform.InverseTransformDirection(bp.rb.angularVelocity));

        //Get position relative to hips in the context of our orientation cube's space
        sensor.AddObservation(cubeTransform.InverseTransformDirection(bp.rb.position - body.position));

        // The body itself contributes no local rotation / joint strength observation.
        if (bp.rb.transform != body)
        {
            sensor.AddObservation(bp.rb.transform.localRotation);
            sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
        }
    }

    /// <summary>
    /// Loop over body parts to add them to observation.
    /// </summary>
    /// <param name="sensor">Sensor the observations are appended to.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Rotation from the body's forward to the orientation cube's forward
        sensor.AddObservation(Quaternion.FromToRotation(body.forward, orientationCube.transform.forward));

        //Add pos of target relative to orientation cube
        sensor.AddObservation(orientationCube.transform.InverseTransformPoint(target.position));

        // Downward ray from the body, normalized by the max ray length; 1 when nothing is hit.
        const float maxRaycastDist = 10;
        RaycastHit hit;
        if (Physics.Raycast(body.position, Vector3.down, out hit, maxRaycastDist))
        {
            sensor.AddObservation(hit.distance / maxRaycastDist);
        }
        else
        {
            sensor.AddObservation(1);
        }

        foreach (var bodyPart in m_JdController.bodyPartsList)
        {
            CollectObservationBodyPart(bodyPart, sensor);
        }
    }

    /// <summary>
    /// Agent touched the target
    /// </summary>
    public void TouchedTarget()
    {
        AddReward(1f);
        if (respawnTargetWhenTouched)
        {
            GetRandomTargetPos();
        }
    }

    /// <summary>
    /// Moves target to a random position within specified radius.
    /// </summary>
    public void GetRandomTargetPos()
    {
        var newTargetPos = Random.insideUnitSphere * targetSpawnRadius;
        // The random height is discarded; the target is always placed 5 units above ground.position.
        newTargetPos.y = 5;
        target.position = newTargetPos + ground.position;
    }

    /// <summary>
    /// Applies an action vector to the leg joints. The first 20 entries are
    /// consumed in order: indices 0-7 are two target-rotation values per upper
    /// leg, 8-11 one target-rotation value per lower leg, and 12-19 one joint
    /// strength per leg segment (upper legs first, then lower legs).
    /// </summary>
    /// <param name="vectorAction">Action values; at least 20 entries are read.</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        // The dictionary with all the body parts in it are in the jdController
        var bpDict = m_JdController.bodyPartsDict;

        var i = -1;
        // Pick a new target joint rotation
        // (C# evaluates arguments left to right, so each ++i consumes the next action in order.)
        bpDict[leg0Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
        bpDict[leg1Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
        bpDict[leg2Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
        bpDict[leg3Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
        bpDict[leg0Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
        bpDict[leg1Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
        bpDict[leg2Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
        bpDict[leg3Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);

        // Update joint strength
        bpDict[leg0Upper].SetJointStrength(vectorAction[++i]);
        bpDict[leg1Upper].SetJointStrength(vectorAction[++i]);
        bpDict[leg2Upper].SetJointStrength(vectorAction[++i]);
        bpDict[leg3Upper].SetJointStrength(vectorAction[++i]);
        bpDict[leg0Lower].SetJointStrength(vectorAction[++i]);
        bpDict[leg1Lower].SetJointStrength(vectorAction[++i]);
        bpDict[leg2Lower].SetJointStrength(vectorAction[++i]);
        bpDict[leg3Lower].SetJointStrength(vectorAction[++i]);
    }

    /// <summary>
    /// Moves the orientation cube onto the body and rotates it to look along
    /// the horizontal direction from the cube's previous position to the target.
    /// </summary>
    void UpdateOrientationCube()
    {
        //FACING DIR
        m_WalkDir = target.position - orientationCube.transform.position;
        m_WalkDir.y = 0; //flatten dir on the y
        m_WalkDirLookRot = Quaternion.LookRotation(m_WalkDir); //get our look rot to the target

        //UPDATE ORIENTATION CUBE POS & ROT
        orientationCube.transform.position = body.position;
        orientationCube.transform.rotation = m_WalkDirLookRot;
    }

    /// <summary>
    /// Per-physics-step update: detects target contact, refreshes the
    /// orientation cube, updates the optional foot visualization and adds the
    /// enabled per-step rewards.
    /// </summary>
    void FixedUpdate()
    {
        if (detectTargets)
        {
            foreach (var bodyPart in m_JdController.bodyPartsList)
            {
                if (bodyPart.targetContact && bodyPart.targetContact.touchingTarget)
                {
                    TouchedTarget();
                }
            }
        }

        UpdateOrientationCube();

        // If enabled the feet will light up green when the foot is grounded.
        // This is just a visualization and isn't necessary for function
        if (useFootGroundedVisualization)
        {
            UpdateFootMaterial(foot0, leg0Lower);
            UpdateFootMaterial(foot1, leg1Lower);
            UpdateFootMaterial(foot2, leg2Lower);
            UpdateFootMaterial(foot3, leg3Lower);
        }

        // Set reward for this step according to mixture of the following elements.
        if (rewardMovingTowardsTarget)
        {
            RewardFunctionMovingTowards();
        }

        if (rewardFacingTarget)
        {
            RewardFunctionFacingTarget();
        }

        if (rewardUseTimePenalty)
        {
            RewardFunctionTimePenalty();
        }
    }

    /// <summary>
    /// Assigns the grounded or ungrounded material to a foot renderer based on
    /// the ground-contact state of the given lower leg.
    /// </summary>
    void UpdateFootMaterial(MeshRenderer foot, Transform lowerLeg)
    {
        foot.material = m_JdController.bodyPartsDict[lowerLeg].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
    }

    /// <summary>
    /// Reward moving towards target & Penalize moving away from target.
    /// </summary>
    void RewardFunctionMovingTowards()
    {
        // Body velocity is clamped so speeds above maximumWalkingSpeed earn no extra reward.
        var clampedVelocity = Vector3.ClampMagnitude(
            m_JdController.bodyPartsDict[body].rb.velocity,
            maximumWalkingSpeed);
        var movingTowardsDot = Vector3.Dot(orientationCube.transform.forward, clampedVelocity);
        AddReward(0.03f * movingTowardsDot);
    }

    /// <summary>
    /// Reward facing target & Penalize facing away from target
    /// </summary>
    void RewardFunctionFacingTarget()
    {
        AddReward(0.01f * Vector3.Dot(orientationCube.transform.forward, body.forward));
    }

    /// <summary>
    /// Existential penalty for time-constrained tasks.
    /// </summary>
    void RewardFunctionTimePenalty()
    {
        // Must be negative: a positive per-step value would reward the agent for taking longer.
        AddReward(-0.0001f);
    }

    /// <summary>
    /// Loop over body parts and reset them to initial conditions.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        SetResetParameters();

        foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
        {
            bodyPart.Reset(bodyPart);
        }

        //Random start rotation to help generalize
        transform.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);

        UpdateOrientationCube();

        if (detectTargets && respawnTargetWhenTouched)
        {
            GetRandomTargetPos();
        }
    }

    /// <summary>
    /// Reads the "dampen" environment parameter (default 3000) into the joint
    /// drive controller's joint dampening.
    /// </summary>
    public void SetResetParameters()
    {
        m_JdController.jointDampen = m_ResetParams.GetWithDefault("dampen", 3000f);
    }

    /// <summary>
    /// While playing, draws the orientation cube as a green wire cube with a
    /// ray along its forward axis.
    /// </summary>
    private void OnDrawGizmosSelected()
    {
        if (Application.isPlaying)
        {
            Gizmos.color = Color.green;
            Gizmos.matrix = orientationCube.transform.localToWorldMatrix;
            Gizmos.DrawWireCube(Vector3.zero, orientationCube.transform.localScale);
            Gizmos.DrawRay(Vector3.zero, Vector3.forward);
        }
    }
}