Vilmantas Balasevicius
5 年前
当前提交
1db18bd6
共有 4 个文件被更改,包括 496 次插入 和 0 次删除
-
292UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedCrawlerAgent.cs
-
3UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedCrawlerAgent.cs.meta
-
198UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedJointDriveController.cs
-
3UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/ArticulatedJointDriveController.cs.meta
|
|||
using UnityEngine; |
|||
using MLAgents; |
|||
|
|||
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
|
|||
public class ArticulatedCrawlerAgent : Agent |
|||
{ |
|||
[Header("Target To Walk Towards")][Space(10)] |
|||
public Transform target; |
|||
|
|||
public Transform ground; |
|||
public bool detectTargets; |
|||
public bool targetIsStatic = false; |
|||
public bool respawnTargetWhenTouched; |
|||
public float targetSpawnRadius; |
|||
|
|||
[Header("Body Parts")][Space(10)] public Transform body; |
|||
public Transform leg0Upper; |
|||
public Transform leg0Lower; |
|||
public Transform leg1Upper; |
|||
public Transform leg1Lower; |
|||
public Transform leg2Upper; |
|||
public Transform leg2Lower; |
|||
public Transform leg3Upper; |
|||
public Transform leg3Lower; |
|||
|
|||
[Header("Joint Settings")][Space(10)] JointDriveController m_JdController; |
|||
Vector3 m_DirToTarget; |
|||
float m_MovingTowardsDot; |
|||
float m_FacingDot; |
|||
|
|||
[Header("Reward Functions To Use")][Space(10)] |
|||
public bool rewardMovingTowardsTarget; // Agent should move towards target
|
|||
|
|||
public bool rewardFacingTarget; // Agent should face the target
|
|||
public bool rewardUseTimePenalty; // Hurry up
|
|||
|
|||
[Header("Foot Grounded Visualization")][Space(10)] |
|||
public bool useFootGroundedVisualization; |
|||
|
|||
public MeshRenderer foot0; |
|||
public MeshRenderer foot1; |
|||
public MeshRenderer foot2; |
|||
public MeshRenderer foot3; |
|||
public Material groundedMaterial; |
|||
public Material unGroundedMaterial; |
|||
bool m_IsNewDecisionStep; |
|||
int m_CurrentDecisionStep; |
|||
|
|||
Quaternion m_LookRotation; |
|||
Matrix4x4 m_TargetDirMatrix; |
|||
|
|||
public override void InitializeAgent() |
|||
{ |
|||
m_JdController = GetComponent<JointDriveController>(); |
|||
m_CurrentDecisionStep = 1; |
|||
m_DirToTarget = target.position - body.position; |
|||
|
|||
|
|||
//Setup each body part
|
|||
m_JdController.SetupBodyPart(body); |
|||
m_JdController.SetupBodyPart(leg0Upper); |
|||
m_JdController.SetupBodyPart(leg0Lower); |
|||
m_JdController.SetupBodyPart(leg1Upper); |
|||
m_JdController.SetupBodyPart(leg1Lower); |
|||
m_JdController.SetupBodyPart(leg2Upper); |
|||
m_JdController.SetupBodyPart(leg2Lower); |
|||
m_JdController.SetupBodyPart(leg3Upper); |
|||
m_JdController.SetupBodyPart(leg3Lower); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// We only need to change the joint settings based on decision freq.
|
|||
/// </summary>
|
|||
public void IncrementDecisionTimer() |
|||
{ |
|||
if (m_CurrentDecisionStep == agentParameters.numberOfActionsBetweenDecisions |
|||
|| agentParameters.numberOfActionsBetweenDecisions == 1) |
|||
{ |
|||
m_CurrentDecisionStep = 1; |
|||
m_IsNewDecisionStep = true; |
|||
} |
|||
else |
|||
{ |
|||
m_CurrentDecisionStep++; |
|||
m_IsNewDecisionStep = false; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Add relevant information on each body part to observations.
|
|||
/// </summary>
|
|||
public void CollectObservationBodyPart(BodyPart bp) |
|||
{ |
|||
var rb = bp.rb; |
|||
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground
|
|||
|
|||
var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.velocity); |
|||
AddVectorObs(velocityRelativeToLookRotationToTarget); |
|||
|
|||
var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity); |
|||
AddVectorObs(angularVelocityRelativeToLookRotationToTarget); |
|||
|
|||
if (bp.rb.transform != body) |
|||
{ |
|||
var localPosRelToBody = body.InverseTransformPoint(rb.position); |
|||
AddVectorObs(localPosRelToBody); |
|||
AddVectorObs(bp.currentXNormalizedRot); // Current x rot
|
|||
AddVectorObs(bp.currentYNormalizedRot); // Current y rot
|
|||
AddVectorObs(bp.currentZNormalizedRot); // Current z rot
|
|||
AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit); |
|||
} |
|||
} |
|||
|
|||
public override void CollectObservations() |
|||
{ |
|||
m_JdController.GetCurrentJointForces(); |
|||
|
|||
// Update pos to target
|
|||
m_DirToTarget = target.position - body.position; |
|||
m_LookRotation = Quaternion.LookRotation(m_DirToTarget); |
|||
m_TargetDirMatrix = Matrix4x4.TRS(Vector3.zero, m_LookRotation, Vector3.one); |
|||
|
|||
RaycastHit hit; |
|||
if (Physics.Raycast(body.position, Vector3.down, out hit, 10.0f)) |
|||
{ |
|||
AddVectorObs(hit.distance); |
|||
} |
|||
else |
|||
AddVectorObs(10.0f); |
|||
|
|||
// Forward & up to help with orientation
|
|||
var bodyForwardRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.forward); |
|||
AddVectorObs(bodyForwardRelativeToLookRotationToTarget); |
|||
|
|||
var bodyUpRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.up); |
|||
AddVectorObs(bodyUpRelativeToLookRotationToTarget); |
|||
|
|||
foreach (var bodyPart in m_JdController.bodyPartsDict.Values) |
|||
{ |
|||
CollectObservationBodyPart(bodyPart); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Agent touched the target
|
|||
/// </summary>
|
|||
public void TouchedTarget() |
|||
{ |
|||
AddReward(1f); |
|||
if (respawnTargetWhenTouched) |
|||
{ |
|||
GetRandomTargetPos(); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Moves target to a random position within specified radius.
|
|||
/// </summary>
|
|||
public void GetRandomTargetPos() |
|||
{ |
|||
var newTargetPos = Random.insideUnitSphere * targetSpawnRadius; |
|||
newTargetPos.y = 5; |
|||
target.position = newTargetPos + ground.position; |
|||
} |
|||
|
|||
public override void AgentAction(float[] vectorAction, string textAction) |
|||
{ |
|||
if (detectTargets) |
|||
{ |
|||
foreach (var bodyPart in m_JdController.bodyPartsDict.Values) |
|||
{ |
|||
if (bodyPart.targetContact && !IsDone() && bodyPart.targetContact.touchingTarget) |
|||
{ |
|||
TouchedTarget(); |
|||
} |
|||
} |
|||
} |
|||
|
|||
// If enabled the feet will light up green when the foot is grounded.
|
|||
// This is just a visualization and isn't necessary for function
|
|||
if (useFootGroundedVisualization) |
|||
{ |
|||
foot0.material = m_JdController.bodyPartsDict[leg0Lower].groundContact.touchingGround |
|||
? groundedMaterial |
|||
: unGroundedMaterial; |
|||
foot1.material = m_JdController.bodyPartsDict[leg1Lower].groundContact.touchingGround |
|||
? groundedMaterial |
|||
: unGroundedMaterial; |
|||
foot2.material = m_JdController.bodyPartsDict[leg2Lower].groundContact.touchingGround |
|||
? groundedMaterial |
|||
: unGroundedMaterial; |
|||
foot3.material = m_JdController.bodyPartsDict[leg3Lower].groundContact.touchingGround |
|||
? groundedMaterial |
|||
: unGroundedMaterial; |
|||
} |
|||
|
|||
// Joint update logic only needs to happen when a new decision is made
|
|||
if (m_IsNewDecisionStep) |
|||
{ |
|||
// The dictionary with all the body parts in it are in the jdController
|
|||
var bpDict = m_JdController.bodyPartsDict; |
|||
|
|||
var i = -1; |
|||
// Pick a new target joint rotation
|
|||
bpDict[leg0Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); |
|||
bpDict[leg1Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); |
|||
bpDict[leg2Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); |
|||
bpDict[leg3Upper].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); |
|||
bpDict[leg0Lower].SetJointTargetRotation(vectorAction[++i], 0, 0); |
|||
bpDict[leg1Lower].SetJointTargetRotation(vectorAction[++i], 0, 0); |
|||
bpDict[leg2Lower].SetJointTargetRotation(vectorAction[++i], 0, 0); |
|||
bpDict[leg3Lower].SetJointTargetRotation(vectorAction[++i], 0, 0); |
|||
|
|||
// Update joint strength
|
|||
bpDict[leg0Upper].SetJointStrength(vectorAction[++i]); |
|||
bpDict[leg1Upper].SetJointStrength(vectorAction[++i]); |
|||
bpDict[leg2Upper].SetJointStrength(vectorAction[++i]); |
|||
bpDict[leg3Upper].SetJointStrength(vectorAction[++i]); |
|||
bpDict[leg0Lower].SetJointStrength(vectorAction[++i]); |
|||
bpDict[leg1Lower].SetJointStrength(vectorAction[++i]); |
|||
bpDict[leg2Lower].SetJointStrength(vectorAction[++i]); |
|||
bpDict[leg3Lower].SetJointStrength(vectorAction[++i]); |
|||
} |
|||
|
|||
// Set reward for this step according to mixture of the following elements.
|
|||
if (rewardMovingTowardsTarget) |
|||
{ |
|||
RewardFunctionMovingTowards(); |
|||
} |
|||
|
|||
if (rewardFacingTarget) |
|||
{ |
|||
RewardFunctionFacingTarget(); |
|||
} |
|||
|
|||
if (rewardUseTimePenalty) |
|||
{ |
|||
RewardFunctionTimePenalty(); |
|||
} |
|||
|
|||
IncrementDecisionTimer(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Reward moving towards target & Penalize moving away from target.
|
|||
/// </summary>
|
|||
void RewardFunctionMovingTowards() |
|||
{ |
|||
m_MovingTowardsDot = Vector3.Dot(m_JdController.bodyPartsDict[body].rb.velocity, m_DirToTarget.normalized); |
|||
AddReward(0.03f * m_MovingTowardsDot); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Reward facing target & Penalize facing away from target
|
|||
/// </summary>
|
|||
void RewardFunctionFacingTarget() |
|||
{ |
|||
m_FacingDot = Vector3.Dot(m_DirToTarget.normalized, body.forward); |
|||
AddReward(0.01f * m_FacingDot); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Existential penalty for time-contrained tasks.
|
|||
/// </summary>
|
|||
void RewardFunctionTimePenalty() |
|||
{ |
|||
AddReward(-0.001f); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Loop over body parts and reset them to initial conditions.
|
|||
/// </summary>
|
|||
public override void AgentReset() |
|||
{ |
|||
if (m_DirToTarget != Vector3.zero) |
|||
{ |
|||
transform.rotation = Quaternion.LookRotation(m_DirToTarget); |
|||
} |
|||
transform.Rotate(Vector3.up, Random.Range(0.0f, 360.0f)); |
|||
|
|||
foreach (var bodyPart in m_JdController.bodyPartsDict.Values) |
|||
{ |
|||
bodyPart.Reset(bodyPart); |
|||
} |
|||
if (!targetIsStatic) |
|||
{ |
|||
GetRandomTargetPos(); |
|||
} |
|||
m_IsNewDecisionStep = true; |
|||
m_CurrentDecisionStep = 1; |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: d7fe662b41d94ad2b94ff2f5fc596210 |
|||
timeCreated: 1572439781 |
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
using UnityEngine.Serialization; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
/// <summary>
|
|||
/// Used to store relevant information for acting and learning for each body part in agent.
|
|||
/// </summary>
|
|||
[System.Serializable] |
|||
public class ArticulationBodyPart |
|||
{ |
|||
//[Header("Body Part Info")][Space(10)] public ConfigurableJoint joint;
|
|||
public ArticulationBody arb; |
|||
[HideInInspector] public Vector3 startingPos; |
|||
[HideInInspector] public Quaternion startingRot; |
|||
|
|||
[Header("Ground & Target Contact")][Space(10)] |
|||
public GroundContact groundContact; |
|||
|
|||
public TargetContact targetContact; |
|||
|
|||
[FormerlySerializedAs("thisJDController")] |
|||
[HideInInspector] public ArticulatedJointDriveController thisJdController; |
|||
|
|||
[Header("Current Joint Settings")][Space(10)] |
|||
public Vector3 currentEularJointRotation; |
|||
|
|||
[HideInInspector] public float currentStrength; |
|||
public float currentXNormalizedRot; |
|||
public float currentYNormalizedRot; |
|||
public float currentZNormalizedRot; |
|||
|
|||
[Header("Other Debug Info")][Space(10)] |
|||
public Vector3 currentJointForce; |
|||
|
|||
public float currentJointForceSqrMag; |
|||
public Vector3 currentJointTorque; |
|||
public float currentJointTorqueSqrMag; |
|||
public AnimationCurve jointForceCurve = new AnimationCurve(); |
|||
public AnimationCurve jointTorqueCurve = new AnimationCurve(); |
|||
|
|||
/// <summary>
|
|||
/// Reset body part to initial configuration.
|
|||
/// </summary>
|
|||
public void Reset(ArticulationBodyPart bp) |
|||
{ |
|||
bp.arb.transform.position = bp.startingPos; |
|||
bp.arb.transform.rotation = bp.startingRot; |
|||
bp.arb.velocity = Vector3.zero; |
|||
bp.arb.angularVelocity = Vector3.zero; |
|||
if (bp.groundContact) |
|||
{ |
|||
bp.groundContact.touchingGround = false; |
|||
} |
|||
|
|||
if (bp.targetContact) |
|||
{ |
|||
bp.targetContact.touchingTarget = false; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Apply torque according to defined goal `x, y, z` angle and force `strength`.
|
|||
/// </summary>
|
|||
public void SetJointTargetRotation(float x, float y, float z) |
|||
{ |
|||
x = (x + 1f) * 0.5f; |
|||
y = (y + 1f) * 0.5f; |
|||
z = (z + 1f) * 0.5f; |
|||
|
|||
var xDrive = arb.xDrive; |
|||
var yDrive = arb.yDrive; |
|||
var zDrive = arb.zDrive; |
|||
|
|||
|
|||
|
|||
var xRot = Mathf.Lerp(xDrive.lowerLimit, xDrive.upperLimit, x); |
|||
var yRot = Mathf.Lerp(yDrive.lowerLimit, yDrive.upperLimit, y); |
|||
var zRot = Mathf.Lerp(zDrive.lowerLimit, zDrive.upperLimit, z); |
|||
|
|||
currentXNormalizedRot = |
|||
Mathf.InverseLerp(xDrive.lowerLimit, xDrive.upperLimit, xRot); |
|||
|
|||
// What is this ? Vilmantas Why lowerLimit is not used ?
|
|||
currentYNormalizedRot = Mathf.InverseLerp(-yDrive.upperLimit, yDrive.upperLimit, yRot); |
|||
currentZNormalizedRot = Mathf.InverseLerp(-zDrive.upperLimit, zDrive.upperLimit, zRot); |
|||
|
|||
//joint.targetRotation = Quaternion.Euler(xRot, yRot, zRot); // Original code
|
|||
xDrive.target = xRot; yDrive.target = yRot; zDrive.target = zRot; |
|||
|
|||
arb.xDrive = xDrive; arb.yDrive = yDrive; arb.zDrive = zDrive; |
|||
|
|||
currentEularJointRotation = new Vector3(xRot, yRot, zRot); |
|||
} |
|||
|
|||
public void SetJointStrength(float strength) |
|||
{ |
|||
ArticulationDrive drive = arb.xDrive; |
|||
|
|||
|
|||
var rawVal = (strength + 1f) * 0.5f * thisJdController.maxJointForceLimit; |
|||
|
|||
drive.stiffness = thisJdController.maxJointSpring; |
|||
drive.damping = thisJdController.jointDampen; |
|||
drive.forceLimit = rawVal; |
|||
|
|||
// Slerp drive does not exist, so we try to set strength for each axis individually
|
|||
arb.xDrive = drive; |
|||
arb.yDrive = drive; |
|||
arb.zDrive = drive; |
|||
//joint.slerpDrive = jd;
|
|||
currentStrength = rawVal; |
|||
} |
|||
} |
|||
|
|||
public class ArticulatedJointDriveController : MonoBehaviour |
|||
{ |
|||
[Header("Joint Drive Settings")][Space(10)] |
|||
public float maxJointSpring; |
|||
|
|||
public float jointDampen; |
|||
public float maxJointForceLimit; |
|||
float m_FacingDot; |
|||
|
|||
[HideInInspector] public Dictionary<Transform, ArticulationBodyPart> bodyPartsDict = new Dictionary<Transform, ArticulationBodyPart>(); |
|||
|
|||
[HideInInspector] public List<ArticulationBodyPart> bodyPartsList = new List<ArticulationBodyPart>(); |
|||
|
|||
/// <summary>
|
|||
/// Create BodyPart object and add it to dictionary.
|
|||
/// </summary>
|
|||
public void SetupBodyPart(Transform t) |
|||
{ |
|||
var bp = new ArticulationBodyPart() |
|||
{ |
|||
arb = t.GetComponent<ArticulationBody>(), |
|||
startingPos = t.position, |
|||
startingRot = t.rotation |
|||
}; |
|||
|
|||
// Does not exist in articulation body
|
|||
//bp.rb.maxAngularVelocity = 100;
|
|||
|
|||
// Add & setup the ground contact script
|
|||
bp.groundContact = t.GetComponent<GroundContact>(); |
|||
if (!bp.groundContact) |
|||
{ |
|||
bp.groundContact = t.gameObject.AddComponent<GroundContact>(); |
|||
bp.groundContact.agent = gameObject.GetComponent<Agent>(); |
|||
} |
|||
else |
|||
{ |
|||
bp.groundContact.agent = gameObject.GetComponent<Agent>(); |
|||
} |
|||
|
|||
// Add & setup the target contact script
|
|||
bp.targetContact = t.GetComponent<TargetContact>(); |
|||
if (!bp.targetContact) |
|||
{ |
|||
bp.targetContact = t.gameObject.AddComponent<TargetContact>(); |
|||
} |
|||
|
|||
bp.thisJdController = this; |
|||
bodyPartsDict.Add(t, bp); |
|||
bodyPartsList.Add(bp); |
|||
} |
|||
|
|||
public void GetCurrentJointForces() |
|||
{ |
|||
foreach (var bodyPart in bodyPartsDict.Values) |
|||
{ |
|||
if (!bodyPart.arb.isRoot) |
|||
{ |
|||
bodyPart.currentJointForce = bodyPart.arb; |
|||
bodyPart.currentJointForceSqrMag = bodyPart.joint.currentForce.magnitude; |
|||
bodyPart.currentJointTorque = bodyPart.joint.currentTorque; |
|||
bodyPart.currentJointTorqueSqrMag = bodyPart.joint.currentTorque.magnitude; |
|||
if (Application.isEditor) |
|||
{ |
|||
if (bodyPart.jointForceCurve.length > 1000) |
|||
{ |
|||
bodyPart.jointForceCurve = new AnimationCurve(); |
|||
} |
|||
|
|||
if (bodyPart.jointTorqueCurve.length > 1000) |
|||
{ |
|||
bodyPart.jointTorqueCurve = new AnimationCurve(); |
|||
} |
|||
|
|||
bodyPart.jointForceCurve.AddKey(Time.time, bodyPart.currentJointForceSqrMag); |
|||
bodyPart.jointTorqueCurve.AddKey(Time.time, bodyPart.currentJointTorqueSqrMag); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: d8020ed16eb94c9aac4589a46facb1fa |
|||
timeCreated: 1572439925 |
撰写
预览
正在加载...
取消
保存
Reference in new issue