using UnityEngine; using MLAgents; public class WalkerAgent : Agent { [Header("Specific to Walker")] [Header("Target To Walk Towards")] [Space(10)] public Transform target; Vector3 m_DirToTarget; public Transform hips; public Transform chest; public Transform spine; public Transform head; public Transform thighL; public Transform shinL; public Transform footL; public Transform thighR; public Transform shinR; public Transform footR; public Transform armL; public Transform forearmL; public Transform handL; public Transform armR; public Transform forearmR; public Transform handR; JointDriveController m_JdController; bool m_IsNewDecisionStep; int m_CurrentDecisionStep; Rigidbody m_HipsRb; Rigidbody m_ChestRb; Rigidbody m_SpineRb; IFloatProperties m_ResetParams; public override void InitializeAgent() { m_JdController = GetComponent(); m_JdController.SetupBodyPart(hips); m_JdController.SetupBodyPart(chest); m_JdController.SetupBodyPart(spine); m_JdController.SetupBodyPart(head); m_JdController.SetupBodyPart(thighL); m_JdController.SetupBodyPart(shinL); m_JdController.SetupBodyPart(footL); m_JdController.SetupBodyPart(thighR); m_JdController.SetupBodyPart(shinR); m_JdController.SetupBodyPart(footR); m_JdController.SetupBodyPart(armL); m_JdController.SetupBodyPart(forearmL); m_JdController.SetupBodyPart(handL); m_JdController.SetupBodyPart(armR); m_JdController.SetupBodyPart(forearmR); m_JdController.SetupBodyPart(handR); m_HipsRb = hips.GetComponent(); m_ChestRb = chest.GetComponent(); m_SpineRb = spine.GetComponent(); var academy = FindObjectOfType(); m_ResetParams = academy.FloatProperties; SetResetParameters(); } /// /// Add relevant information on each body part to observations. /// public void CollectObservationBodyPart(BodyPart bp) { var rb = bp.rb; AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Is this bp touching the ground AddVectorObs(rb.velocity); AddVectorObs(rb.angularVelocity); var localPosRelToHips = hips.InverseTransformPoint(rb.position); AddVectorObs(localPosRelToHips); if (bp.rb.transform != hips && bp.rb.transform != handL && bp.rb.transform != handR && bp.rb.transform != footL && bp.rb.transform != footR && bp.rb.transform != head) { AddVectorObs(bp.currentXNormalizedRot); AddVectorObs(bp.currentYNormalizedRot); AddVectorObs(bp.currentZNormalizedRot); AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit); } } /// /// Loop over body parts to add them to observation. /// public override void CollectObservations() { m_JdController.GetCurrentJointForces(); AddVectorObs(m_DirToTarget.normalized); AddVectorObs(m_JdController.bodyPartsDict[hips].rb.position); AddVectorObs(hips.forward); AddVectorObs(hips.up); foreach (var bodyPart in m_JdController.bodyPartsDict.Values) { CollectObservationBodyPart(bodyPart); } } public override void AgentAction(float[] vectorAction) { m_DirToTarget = target.position - m_JdController.bodyPartsDict[hips].rb.position; // Apply action to all relevant body parts. if (m_IsNewDecisionStep) { var bpDict = m_JdController.bodyPartsDict; var i = -1; bpDict[chest].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]); bpDict[spine].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]); bpDict[thighL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); bpDict[thighR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); bpDict[shinL].SetJointTargetRotation(vectorAction[++i], 0, 0); bpDict[shinR].SetJointTargetRotation(vectorAction[++i], 0, 0); bpDict[footR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]); bpDict[footL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]); bpDict[armL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); bpDict[armR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); bpDict[forearmL].SetJointTargetRotation(vectorAction[++i], 0, 0); bpDict[forearmR].SetJointTargetRotation(vectorAction[++i], 0, 0); bpDict[head].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); //update joint strength settings bpDict[chest].SetJointStrength(vectorAction[++i]); bpDict[spine].SetJointStrength(vectorAction[++i]); bpDict[head].SetJointStrength(vectorAction[++i]); bpDict[thighL].SetJointStrength(vectorAction[++i]); bpDict[shinL].SetJointStrength(vectorAction[++i]); bpDict[footL].SetJointStrength(vectorAction[++i]); bpDict[thighR].SetJointStrength(vectorAction[++i]); bpDict[shinR].SetJointStrength(vectorAction[++i]); bpDict[footR].SetJointStrength(vectorAction[++i]); bpDict[armL].SetJointStrength(vectorAction[++i]); bpDict[forearmL].SetJointStrength(vectorAction[++i]); bpDict[armR].SetJointStrength(vectorAction[++i]); bpDict[forearmR].SetJointStrength(vectorAction[++i]); } IncrementDecisionTimer(); // Set reward for this step according to mixture of the following elements. // a. Velocity alignment with goal direction. // b. Rotation alignment with goal direction. // c. Encourage head height. // d. Discourage head movement. AddReward( +0.03f * Vector3.Dot(m_DirToTarget.normalized, m_JdController.bodyPartsDict[hips].rb.velocity) + 0.01f * Vector3.Dot(m_DirToTarget.normalized, hips.forward) + 0.02f * (head.position.y - hips.position.y) - 0.01f * Vector3.Distance(m_JdController.bodyPartsDict[head].rb.velocity, m_JdController.bodyPartsDict[hips].rb.velocity) ); } /// /// Only change the joint settings based on decision frequency. /// public void IncrementDecisionTimer() { if (m_CurrentDecisionStep == agentParameters.numberOfActionsBetweenDecisions || agentParameters.numberOfActionsBetweenDecisions == 1) { m_CurrentDecisionStep = 1; m_IsNewDecisionStep = true; } else { m_CurrentDecisionStep++; m_IsNewDecisionStep = false; } } /// /// Loop over body parts and reset them to initial conditions. /// public override void AgentReset() { if (m_DirToTarget != Vector3.zero) { transform.rotation = Quaternion.LookRotation(m_DirToTarget); } foreach (var bodyPart in m_JdController.bodyPartsDict.Values) { bodyPart.Reset(bodyPart); } m_IsNewDecisionStep = true; m_CurrentDecisionStep = 1; SetResetParameters(); } public void SetTorsoMass() { m_ChestRb.mass = m_ResetParams.GetPropertyWithDefault("chest_mass", 8); m_SpineRb.mass = m_ResetParams.GetPropertyWithDefault("spine_mass", 10); m_HipsRb.mass = m_ResetParams.GetPropertyWithDefault("hip_mass", 15); } public void SetResetParameters() { SetTorsoMass(); } }