浏览代码

Initial changes for articulated crawler.

/PhysXArticulations20201
Vilmantas Balasevicius 5 年前
当前提交
1db18bd6
共有 4 个文件被更改,包括 496 次插入0 次删除
  1. 292
      UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedCrawlerAgent.cs
  2. 3
      UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedCrawlerAgent.cs.meta
  3. 198
      UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedJointDriveController.cs
  4. 3
      UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedJointDriveController.cs.meta

292
UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedCrawlerAgent.cs


using UnityEngine;
using MLAgents;
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class ArticulatedCrawlerAgent : Agent
{
[Header("Target To Walk Towards")][Space(10)]
public Transform target;
public Transform ground;
public bool detectTargets;
public bool targetIsStatic = false;
public bool respawnTargetWhenTouched;
public float targetSpawnRadius;
[Header("Body Parts")][Space(10)] public Transform body;
public Transform leg0Upper;
public Transform leg0Lower;
public Transform leg1Upper;
public Transform leg1Lower;
public Transform leg2Upper;
public Transform leg2Lower;
public Transform leg3Upper;
public Transform leg3Lower;
[Header("Joint Settings")][Space(10)] JointDriveController m_JdController;
Vector3 m_DirToTarget;
float m_MovingTowardsDot;
float m_FacingDot;
[Header("Reward Functions To Use")][Space(10)]
public bool rewardMovingTowardsTarget; // Agent should move towards target
public bool rewardFacingTarget; // Agent should face the target
public bool rewardUseTimePenalty; // Hurry up
[Header("Foot Grounded Visualization")][Space(10)]
public bool useFootGroundedVisualization;
public MeshRenderer foot0;
public MeshRenderer foot1;
public MeshRenderer foot2;
public MeshRenderer foot3;
public Material groundedMaterial;
public Material unGroundedMaterial;
bool m_IsNewDecisionStep;
int m_CurrentDecisionStep;
Quaternion m_LookRotation;
Matrix4x4 m_TargetDirMatrix;
public override void InitializeAgent()
{
m_JdController = GetComponent<JointDriveController>();
m_CurrentDecisionStep = 1;
m_DirToTarget = target.position - body.position;
//Setup each body part
m_JdController.SetupBodyPart(body);
m_JdController.SetupBodyPart(leg0Upper);
m_JdController.SetupBodyPart(leg0Lower);
m_JdController.SetupBodyPart(leg1Upper);
m_JdController.SetupBodyPart(leg1Lower);
m_JdController.SetupBodyPart(leg2Upper);
m_JdController.SetupBodyPart(leg2Lower);
m_JdController.SetupBodyPart(leg3Upper);
m_JdController.SetupBodyPart(leg3Lower);
}
/// <summary>
/// We only need to change the joint settings based on decision freq.
/// </summary>
public void IncrementDecisionTimer()
{
if (m_CurrentDecisionStep == agentParameters.numberOfActionsBetweenDecisions
|| agentParameters.numberOfActionsBetweenDecisions == 1)
{
m_CurrentDecisionStep = 1;
m_IsNewDecisionStep = true;
}
else
{
m_CurrentDecisionStep++;
m_IsNewDecisionStep = false;
}
}
/// <summary>
/// Add relevant information on each body part to observations.
/// </summary>
public void CollectObservationBodyPart(BodyPart bp)
{
var rb = bp.rb;
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground
var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.velocity);
AddVectorObs(velocityRelativeToLookRotationToTarget);
var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
AddVectorObs(angularVelocityRelativeToLookRotationToTarget);
if (bp.rb.transform != body)
{
var localPosRelToBody = body.InverseTransformPoint(rb.position);
AddVectorObs(localPosRelToBody);
AddVectorObs(bp.currentXNormalizedRot); // Current x rot
AddVectorObs(bp.currentYNormalizedRot); // Current y rot
AddVectorObs(bp.currentZNormalizedRot); // Current z rot
AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit);
}
}
public override void CollectObservations()
{
m_JdController.GetCurrentJointForces();
// Update pos to target
m_DirToTarget = target.position - body.position;
m_LookRotation = Quaternion.LookRotation(m_DirToTarget);
m_TargetDirMatrix = Matrix4x4.TRS(Vector3.zero, m_LookRotation, Vector3.one);
RaycastHit hit;
if (Physics.Raycast(body.position, Vector3.down, out hit, 10.0f))
{
AddVectorObs(hit.distance);
}
else
AddVectorObs(10.0f);
// Forward & up to help with orientation
var bodyForwardRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.forward);
AddVectorObs(bodyForwardRelativeToLookRotationToTarget);
var bodyUpRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.up);
AddVectorObs(bodyUpRelativeToLookRotationToTarget);
foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
{
CollectObservationBodyPart(bodyPart);
}
}
/// <summary>
/// Agent touched the target
/// </summary>
public void TouchedTarget()
{
AddReward(1f);
if (respawnTargetWhenTouched)
{
GetRandomTargetPos();
}
}
/// <summary>
/// Moves target to a random position within specified radius.
/// </summary>
public void GetRandomTargetPos()
{
var newTargetPos = Random.insideUnitSphere * targetSpawnRadius;
newTargetPos.y = 5;
target.position = newTargetPos + ground.position;
}
public override void AgentAction(float[] vectorAction, string textAction)
{
if (detectTargets)
{
foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
{
if (bodyPart.targetContact && !IsDone() && bodyPart.targetContact.touchingTarget)
{
TouchedTarget();
}
}
}
// If enabled the feet will light up green when the foot is grounded.
// This is just a visualization and isn't necessary for function
if (useFootGroundedVisualization)
{
foot0.material = m_JdController.bodyPartsDict[leg0Lower].groundContact.touchingGround
? groundedMaterial
: unGroundedMaterial;
foot1.material = m_JdController.bodyPartsDict[leg1Lower].groundContact.touchingGround
? groundedMaterial
: unGroundedMaterial;
foot2.material = m_JdController.bodyPartsDict[leg2Lower].groundContact.touchingGround
? groundedMaterial
: unGroundedMaterial;
foot3.material = m_JdController.bodyPartsDict[leg3Lower].groundContact.touchingGround
? groundedMaterial
: unGroundedMaterial;
}
// Joint update logic only needs to happen when a new decision is made
if (m_IsNewDecisionStep)
{
// The dictionary with all the body parts in it are in the jdController
var bpDict = m_JdController.bodyPartsDict;
var i = -1;
// Pick a new target joint rotation
bpDict[leg0Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg1Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg2Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg3Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0);
bpDict[leg0Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg1Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg2Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
bpDict[leg3Lower].SetJointTargetRotation(vectorAction[++i], 0, 0);
// Update joint strength
bpDict[leg0Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg1Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg2Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg3Upper].SetJointStrength(vectorAction[++i]);
bpDict[leg0Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg1Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg2Lower].SetJointStrength(vectorAction[++i]);
bpDict[leg3Lower].SetJointStrength(vectorAction[++i]);
}
// Set reward for this step according to mixture of the following elements.
if (rewardMovingTowardsTarget)
{
RewardFunctionMovingTowards();
}
if (rewardFacingTarget)
{
RewardFunctionFacingTarget();
}
if (rewardUseTimePenalty)
{
RewardFunctionTimePenalty();
}
IncrementDecisionTimer();
}
/// <summary>
/// Reward moving towards target & Penalize moving away from target.
/// </summary>
void RewardFunctionMovingTowards()
{
m_MovingTowardsDot = Vector3.Dot(m_JdController.bodyPartsDict[body].rb.velocity, m_DirToTarget.normalized);
AddReward(0.03f * m_MovingTowardsDot);
}
/// <summary>
/// Reward facing target & Penalize facing away from target
/// </summary>
void RewardFunctionFacingTarget()
{
m_FacingDot = Vector3.Dot(m_DirToTarget.normalized, body.forward);
AddReward(0.01f * m_FacingDot);
}
/// <summary>
/// Existential penalty for time-contrained tasks.
/// </summary>
void RewardFunctionTimePenalty()
{
AddReward(-0.001f);
}
/// <summary>
/// Loop over body parts and reset them to initial conditions.
/// </summary>
public override void AgentReset()
{
if (m_DirToTarget != Vector3.zero)
{
transform.rotation = Quaternion.LookRotation(m_DirToTarget);
}
transform.Rotate(Vector3.up, Random.Range(0.0f, 360.0f));
foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
{
bodyPart.Reset(bodyPart);
}
if (!targetIsStatic)
{
GetRandomTargetPos();
}
m_IsNewDecisionStep = true;
m_CurrentDecisionStep = 1;
}
}

3
UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedCrawlerAgent.cs.meta


fileFormatVersion: 2
guid: d7fe662b41d94ad2b94ff2f5fc596210
timeCreated: 1572439781

198
UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedJointDriveController.cs


using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Serialization;
namespace MLAgents
{
/// <summary>
/// Used to store relevant information for acting and learning for each body part in agent.
/// </summary>
[System.Serializable]
public class ArticulationBodyPart
{
//[Header("Body Part Info")][Space(10)] public ConfigurableJoint joint;
public ArticulationBody arb;
[HideInInspector] public Vector3 startingPos;
[HideInInspector] public Quaternion startingRot;
[Header("Ground & Target Contact")][Space(10)]
public GroundContact groundContact;
public TargetContact targetContact;
[FormerlySerializedAs("thisJDController")]
[HideInInspector] public ArticulatedJointDriveController thisJdController;
[Header("Current Joint Settings")][Space(10)]
public Vector3 currentEularJointRotation;
[HideInInspector] public float currentStrength;
public float currentXNormalizedRot;
public float currentYNormalizedRot;
public float currentZNormalizedRot;
[Header("Other Debug Info")][Space(10)]
public Vector3 currentJointForce;
public float currentJointForceSqrMag;
public Vector3 currentJointTorque;
public float currentJointTorqueSqrMag;
public AnimationCurve jointForceCurve = new AnimationCurve();
public AnimationCurve jointTorqueCurve = new AnimationCurve();
/// <summary>
/// Reset body part to initial configuration.
/// </summary>
public void Reset(ArticulationBodyPart bp)
{
bp.arb.transform.position = bp.startingPos;
bp.arb.transform.rotation = bp.startingRot;
bp.arb.velocity = Vector3.zero;
bp.arb.angularVelocity = Vector3.zero;
if (bp.groundContact)
{
bp.groundContact.touchingGround = false;
}
if (bp.targetContact)
{
bp.targetContact.touchingTarget = false;
}
}
/// <summary>
/// Apply torque according to defined goal `x, y, z` angle and force `strength`.
/// </summary>
public void SetJointTargetRotation(float x, float y, float z)
{
x = (x + 1f) * 0.5f;
y = (y + 1f) * 0.5f;
z = (z + 1f) * 0.5f;
var xDrive = arb.xDrive;
var yDrive = arb.yDrive;
var zDrive = arb.zDrive;
var xRot = Mathf.Lerp(xDrive.lowerLimit, xDrive.upperLimit, x);
var yRot = Mathf.Lerp(yDrive.lowerLimit, yDrive.upperLimit, y);
var zRot = Mathf.Lerp(zDrive.lowerLimit, zDrive.upperLimit, z);
currentXNormalizedRot =
Mathf.InverseLerp(xDrive.lowerLimit, xDrive.upperLimit, xRot);
// What is this ? Vilmantas Why lowerLimit is not used ?
currentYNormalizedRot = Mathf.InverseLerp(-yDrive.upperLimit, yDrive.upperLimit, yRot);
currentZNormalizedRot = Mathf.InverseLerp(-zDrive.upperLimit, zDrive.upperLimit, zRot);
//joint.targetRotation = Quaternion.Euler(xRot, yRot, zRot); // Original code
xDrive.target = xRot; yDrive.target = yRot; zDrive.target = zRot;
arb.xDrive = xDrive; arb.yDrive = yDrive; arb.zDrive = zDrive;
currentEularJointRotation = new Vector3(xRot, yRot, zRot);
}
public void SetJointStrength(float strength)
{
ArticulationDrive drive = arb.xDrive;
var rawVal = (strength + 1f) * 0.5f * thisJdController.maxJointForceLimit;
drive.stiffness = thisJdController.maxJointSpring;
drive.damping = thisJdController.jointDampen;
drive.forceLimit = rawVal;
// Slerp drive does not exist, so we try to set strength for each axis individually
arb.xDrive = drive;
arb.yDrive = drive;
arb.zDrive = drive;
//joint.slerpDrive = jd;
currentStrength = rawVal;
}
}
public class ArticulatedJointDriveController : MonoBehaviour
{
[Header("Joint Drive Settings")][Space(10)]
public float maxJointSpring;
public float jointDampen;
public float maxJointForceLimit;
float m_FacingDot;
[HideInInspector] public Dictionary<Transform, ArticulationBodyPart> bodyPartsDict = new Dictionary<Transform, ArticulationBodyPart>();
[HideInInspector] public List<ArticulationBodyPart> bodyPartsList = new List<ArticulationBodyPart>();
/// <summary>
/// Create BodyPart object and add it to dictionary.
/// </summary>
public void SetupBodyPart(Transform t)
{
var bp = new ArticulationBodyPart()
{
arb = t.GetComponent<ArticulationBody>(),
startingPos = t.position,
startingRot = t.rotation
};
// Does not exist in articulation body
//bp.rb.maxAngularVelocity = 100;
// Add & setup the ground contact script
bp.groundContact = t.GetComponent<GroundContact>();
if (!bp.groundContact)
{
bp.groundContact = t.gameObject.AddComponent<GroundContact>();
bp.groundContact.agent = gameObject.GetComponent<Agent>();
}
else
{
bp.groundContact.agent = gameObject.GetComponent<Agent>();
}
// Add & setup the target contact script
bp.targetContact = t.GetComponent<TargetContact>();
if (!bp.targetContact)
{
bp.targetContact = t.gameObject.AddComponent<TargetContact>();
}
bp.thisJdController = this;
bodyPartsDict.Add(t, bp);
bodyPartsList.Add(bp);
}
public void GetCurrentJointForces()
{
foreach (var bodyPart in bodyPartsDict.Values)
{
if (!bodyPart.arb.isRoot)
{
bodyPart.currentJointForce = bodyPart.arb;
bodyPart.currentJointForceSqrMag = bodyPart.joint.currentForce.magnitude;
bodyPart.currentJointTorque = bodyPart.joint.currentTorque;
bodyPart.currentJointTorqueSqrMag = bodyPart.joint.currentTorque.magnitude;
if (Application.isEditor)
{
if (bodyPart.jointForceCurve.length > 1000)
{
bodyPart.jointForceCurve = new AnimationCurve();
}
if (bodyPart.jointTorqueCurve.length > 1000)
{
bodyPart.jointTorqueCurve = new AnimationCurve();
}
bodyPart.jointForceCurve.AddKey(Time.time, bodyPart.currentJointForceSqrMag);
bodyPart.jointTorqueCurve.AddKey(Time.time, bodyPart.currentJointTorqueSqrMag);
}
}
}
}
}
}

3
UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedJointDriveController.cs.meta


fileFormatVersion: 2
guid: d8020ed16eb94c9aac4589a46facb1fa
timeCreated: 1572439925
正在加载...
取消
保存