using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using Unity.MLAgentsExamples;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

/// <summary>
/// ML-Agents agent that drives a four-legged crawler ragdoll towards a target.
/// Each physics step it is rewarded for matching a goal walking velocity and for
/// facing the direction given by the orientation cube.
/// </summary>
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class CrawlerAgent : Agent
{
    /// <summary>
    /// Selects which behavior name / model is applied and how the target and
    /// walking speed are set up in <see cref="Initialize"/>.
    /// </summary>
    public enum CrawlerAgentBehaviorType
    {
        CrawlerDynamic,
        CrawlerDynamicVariableSpeed,
        CrawlerStatic,
        CrawlerStaticVariableSpeed
    }

    public CrawlerAgentBehaviorType typeOfCrawler;

    //Crawler Brains
    //A different brain will be used depending on the CrawlerAgentBehaviorType selected
    [Header("NN Models")]
    public NNModel crawlerDyBrain;
    public NNModel crawlerDyVSBrain;
    public NNModel crawlerStBrain;
    public NNModel crawlerStVSBrain;

    const float m_minWalkingSpeed = 0.1f; //The min walking speed
    const float m_maxWalkingSpeed = 10; //The max walking speed

    [Header("Walk Speed")]
    [Range(m_minWalkingSpeed, m_maxWalkingSpeed)]
    [SerializeField]
    //The walking speed to try and achieve
    private float m_TargetWalkingSpeed = 10;

    /// <summary>
    /// Goal walking speed. Assigned values are clamped to
    /// [m_minWalkingSpeed, m_maxWalkingSpeed].
    /// </summary>
    public float TargetWalkingSpeed
    {
        get { return m_TargetWalkingSpeed; }
        set { m_TargetWalkingSpeed = Mathf.Clamp(value, m_minWalkingSpeed, m_maxWalkingSpeed); }
    }

    //Should the agent sample a new goal velocity each episode?
    //If true, TargetWalkingSpeed will be randomly set between m_minWalkingSpeed and m_maxWalkingSpeed in OnEpisodeBegin()
    //If false, the goal velocity will be the current TargetWalkingSpeed
    //NOTE: Initialize() overwrites this flag according to typeOfCrawler.
    public bool randomizeWalkSpeedEachEpisode;

    //The direction from the body to the target, refreshed in UpdateOrientationObjects().
    //It is written but not read anywhere in this class.
    private Vector3 m_WorldDirToWalk = Vector3.right;

    [Header("Target To Walk Towards")]
    public Transform targetPrefab; //Target prefab
    public Transform target; //Target the agent will walk towards during training. Replaced by an instance of targetPrefab in Initialize().

    [Header("Body Parts")]
    [Space(10)]
    public Transform body;
    public Transform leg0Upper;
    public Transform leg0Lower;
    public Transform leg1Upper;
    public Transform leg1Lower;
    public Transform leg2Upper;
    public Transform leg2Lower;
    public Transform leg3Upper;
    public Transform leg3Lower;

    //This will be used as a stabilized model space reference point for observations
    //Because ragdolls can move erratically during training, using a stabilized reference transform improves learning
    OrientationCubeController m_OrientationCube;

    //The indicator graphic gameobject that points towards the target
    DirectionIndicator m_DirectionIndicator;
    JointDriveController m_JdController;

    [Header("Foot Grounded Visualization")]
    [Space(10)]
    public bool useFootGroundedVisualization;
    public MeshRenderer foot0;
    public MeshRenderer foot1;
    public MeshRenderer foot2;
    public MeshRenderer foot3;
    public Material groundedMaterial;
    public Material unGroundedMaterial;

    private Unity.MLAgents.Policies.BehaviorParameters m_BehaviorParams;

    /// <summary>
    /// Applies the behavior name, model, target placement and speed randomization
    /// that correspond to <see cref="typeOfCrawler"/>, caches the helper components
    /// and registers every body part with the joint drive controller.
    /// </summary>
    public override void Initialize()
    {
        m_BehaviorParams = GetComponent<Unity.MLAgents.Policies.BehaviorParameters>();

        switch (typeOfCrawler)
        {
            case CrawlerAgentBehaviorType.CrawlerDynamic:
                ConfigureBehavior("CrawlerDynamic", crawlerDyBrain, transform.position, false);
                break;
            case CrawlerAgentBehaviorType.CrawlerDynamicVariableSpeed:
                ConfigureBehavior("CrawlerDynamicVariableSpeed", crawlerDyVSBrain, transform.position, true);
                break;
            case CrawlerAgentBehaviorType.CrawlerStatic:
                ConfigureBehavior("CrawlerStatic", crawlerStBrain, GetStaticTargetSpawnPosition(), false);
                break;
            case CrawlerAgentBehaviorType.CrawlerStaticVariableSpeed:
                ConfigureBehavior("CrawlerStaticVariableSpeed", crawlerStVSBrain, GetStaticTargetSpawnPosition(), true);
                break;
        }

        m_OrientationCube = GetComponentInChildren<OrientationCubeController>();
        m_DirectionIndicator = GetComponentInChildren<DirectionIndicator>();
        m_JdController = GetComponent<JointDriveController>();

        //Setup each body part
        m_JdController.SetupBodyPart(body);
        m_JdController.SetupBodyPart(leg0Upper);
        m_JdController.SetupBodyPart(leg0Lower);
        m_JdController.SetupBodyPart(leg1Upper);
        m_JdController.SetupBodyPart(leg1Lower);
        m_JdController.SetupBodyPart(leg2Upper);
        m_JdController.SetupBodyPart(leg2Lower);
        m_JdController.SetupBodyPart(leg3Upper);
        m_JdController.SetupBodyPart(leg3Lower);
    }

    /// <summary>
    /// Shared setup for one behavior type: sets the behavior name, assigns the model
    /// when one is provided, spawns the target as a child of this transform and
    /// records whether the walking speed is randomized each episode.
    /// </summary>
    void ConfigureBehavior(string behaviorName, NNModel model, Vector3 targetSpawnPos, bool randomizeSpeed)
    {
        m_BehaviorParams.BehaviorName = behaviorName;

        //Leave the model already set on the BehaviorParameters untouched when none is assigned here
        if (model)
        {
            m_BehaviorParams.Model = model;
        }

        target = Instantiate(targetPrefab, targetSpawnPos, Quaternion.identity, transform);
        randomizeWalkSpeedEachEpisode = randomizeSpeed;
    }

    /// <summary>
    /// World position used for the target of the "static" behaviors:
    /// 1000 units along this transform's local z axis.
    /// </summary>
    Vector3 GetStaticTargetSpawnPosition()
    {
        return transform.TransformPoint(new Vector3(0, 0, 1000));
    }

    /// <summary>
    /// Loop over body parts and reset them to initial conditions.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
        {
            bodyPart.Reset(bodyPart);
        }

        //Random start rotation to help generalize
        body.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);

        UpdateOrientationObjects();

        //Set our goal walking speed.
        //The non-random branch deliberately re-assigns the current value so it still passes through the clamping setter.
        TargetWalkingSpeed = randomizeWalkSpeedEachEpisode
            ? Random.Range(m_minWalkingSpeed, m_maxWalkingSpeed)
            : TargetWalkingSpeed;
    }

    /// <summary>
    /// Add relevant information on each body part to observations.
    /// </summary>
    /// <param name="bp">Body part to observe.</param>
    /// <param name="sensor">Sensor the observations are written to.</param>
    public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
    {
        //GROUND CHECK
        sensor.AddObservation(bp.groundContact.touchingGround); // Is this bp touching the ground

        //Joint strength is observed for every part except the body itself
        if (bp.rb.transform != body)
        {
            sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
        }
    }

    /// <summary>
    /// Loop over body parts to add them to observation.
    /// </summary>
    /// <param name="sensor">Sensor the observations are written to.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        var cubeForward = m_OrientationCube.transform.forward;

        //velocity we want to match
        var velGoal = cubeForward * TargetWalkingSpeed;
        //ragdoll's avg vel
        var avgVel = GetAvgVelocity();

        //distance between the goal velocity and the current average velocity
        sensor.AddObservation(Vector3.Distance(velGoal, avgVel));
        //avg body vel relative to cube
        sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(avgVel));
        //vel goal relative to cube
        sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(velGoal));
        //rotation delta
        sensor.AddObservation(Quaternion.FromToRotation(body.forward, cubeForward));

        //Add pos of target relative to orientation cube
        sensor.AddObservation(m_OrientationCube.transform.InverseTransformPoint(target.transform.position));

        //Distance from the body straight down to the first collider hit, scaled by the ray length; 1 when nothing is hit
        RaycastHit hit;
        const float maxRaycastDist = 10;
        if (Physics.Raycast(body.position, Vector3.down, out hit, maxRaycastDist))
        {
            sensor.AddObservation(hit.distance / maxRaycastDist);
        }
        else
        {
            sensor.AddObservation(1);
        }

        foreach (var bodyPart in m_JdController.bodyPartsList)
        {
            CollectObservationBodyPart(bodyPart, sensor);
        }
    }

    /// <summary>
    /// Applies the continuous actions to the joints: the first 12 values are target
    /// rotations (two per upper leg, one per lower leg) and the following 8 values
    /// are joint strengths, consumed in the order written below.
    /// </summary>
    /// <param name="actionBuffers">Actions for this step; 20 continuous actions are read.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // The dictionary with all the body parts in it are in the jdController
        var bpDict = m_JdController.bodyPartsDict;

        var continuousActions = actionBuffers.ContinuousActions;
        //Arguments are evaluated left to right, so each ++i consumes the next action in order
        var i = -1;

        // Pick a new target joint rotation
        bpDict[leg0Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[leg1Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[leg2Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[leg3Upper].SetJointTargetRotation(continuousActions[++i], continuousActions[++i], 0);
        bpDict[leg0Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[leg1Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[leg2Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);
        bpDict[leg3Lower].SetJointTargetRotation(continuousActions[++i], 0, 0);

        // Update joint strength
        bpDict[leg0Upper].SetJointStrength(continuousActions[++i]);
        bpDict[leg1Upper].SetJointStrength(continuousActions[++i]);
        bpDict[leg2Upper].SetJointStrength(continuousActions[++i]);
        bpDict[leg3Upper].SetJointStrength(continuousActions[++i]);
        bpDict[leg0Lower].SetJointStrength(continuousActions[++i]);
        bpDict[leg1Lower].SetJointStrength(continuousActions[++i]);
        bpDict[leg2Lower].SetJointStrength(continuousActions[++i]);
        bpDict[leg3Lower].SetJointStrength(continuousActions[++i]);
    }

    /// <summary>
    /// Per physics step: refreshes the orientation helpers, optionally updates the
    /// foot materials, and adds the product of the speed-matching reward and the
    /// facing reward.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when either reward term is NaN.</exception>
    void FixedUpdate()
    {
        UpdateOrientationObjects();

        // If enabled the feet will light up green when the foot is grounded.
        // This is just a visualization and isn't necessary for function
        if (useFootGroundedVisualization)
        {
            UpdateFootMaterial(foot0, leg0Lower);
            UpdateFootMaterial(foot1, leg1Lower);
            UpdateFootMaterial(foot2, leg2Lower);
            UpdateFootMaterial(foot3, leg3Lower);
        }

        var cubeForward = m_OrientationCube.transform.forward;

        // Set reward for this step according to mixture of the following elements.
        // a. Match target speed
        //This reward will approach 1 if it matches perfectly and approach zero as it deviates
        var matchSpeedReward = GetMatchingVelocityReward(cubeForward * TargetWalkingSpeed, GetAvgVelocity());

        //Check for NaNs
        if (float.IsNaN(matchSpeedReward))
        {
            throw new ArgumentException(
                "NaN in matchSpeedReward.\n" +
                $" cubeForward: {cubeForward}\n" +
                $" hips.velocity: {m_JdController.bodyPartsDict[body].rb.velocity}\n" +
                $" maximumWalkingSpeed: {m_maxWalkingSpeed}"
            );
        }

        // b. Rotation alignment with target direction.
        //This reward will approach 1 if it faces the target direction perfectly and approach zero as it deviates
        var lookAtTargetReward = (Vector3.Dot(cubeForward, body.forward) + 1) * .5F;

        //Check for NaNs
        if (float.IsNaN(lookAtTargetReward))
        {
            throw new ArgumentException(
                "NaN in lookAtTargetReward.\n" +
                $" cubeForward: {cubeForward}\n" +
                $" body.forward: {body.forward}"
            );
        }

        AddReward(matchSpeedReward * lookAtTargetReward);
    }

    /// <summary>
    /// Assigns the grounded or ungrounded material to a foot renderer depending on
    /// whether the given lower leg currently reports ground contact.
    /// </summary>
    void UpdateFootMaterial(MeshRenderer foot, Transform lowerLeg)
    {
        foot.material = m_JdController.bodyPartsDict[lowerLeg].groundContact.touchingGround
            ? groundedMaterial
            : unGroundedMaterial;
    }

    //Update OrientationCube and DirectionIndicator
    void UpdateOrientationObjects()
    {
        m_WorldDirToWalk = target.position - body.position;
        m_OrientationCube.UpdateOrientation(body, target);

        //The direction indicator is optional
        if (m_DirectionIndicator)
        {
            m_DirectionIndicator.MatchOrientation(m_OrientationCube.transform);
        }
    }

    //Returns the average velocity of all of the body parts
    //Using the velocity of the hips only has shown to result in more erratic movement from the limbs, so...
    //...using the average helps prevent this erratic movement
    Vector3 GetAvgVelocity()
    {
        Vector3 velSum = Vector3.zero;

        //ALL RBS
        int numOfRB = 0;
        foreach (var item in m_JdController.bodyPartsList)
        {
            numOfRB++;
            velSum += item.rb.velocity;
        }

        return velSum / numOfRB;
    }

    /// <summary>
    /// Normalized value of the difference in avg speed vs goal walking speed.
    /// </summary>
    /// <param name="velocityGoal">Velocity the agent should match.</param>
    /// <param name="actualVelocity">Velocity the agent currently has.</param>
    /// <returns>
    /// (1 - (d / TargetWalkingSpeed)^2)^2, where d is the distance between the two
    /// velocities clamped to [0, TargetWalkingSpeed]; 1 for a perfect match, 0 at or
    /// beyond the clamp limit.
    /// </returns>
    public float GetMatchingVelocityReward(Vector3 velocityGoal, Vector3 actualVelocity)
    {
        //distance between our actual velocity and goal velocity
        var velDeltaMagnitude = Mathf.Clamp(Vector3.Distance(actualVelocity, velocityGoal), 0, TargetWalkingSpeed);

        //return the value on a declining sigmoid shaped curve that decays from 1 to 0
        //This reward will approach 1 if it matches perfectly and approach zero as it deviates
        return Mathf.Pow(1 - Mathf.Pow(velDeltaMagnitude / TargetWalkingSpeed, 2), 2);
    }

    /// <summary>
    /// Agent touched the target
    /// </summary>
    public void TouchedTarget()
    {
        AddReward(1f);
    }
}