|
|
|
|
|
|
using MLAgentsExamples; |
|
|
|
using UnityEditor; |
|
|
|
using BodyPart = Unity.MLAgentsExamples.BodyPart; |
|
|
|
[Header("Walking Speed")] |
|
|
|
[Space(10)] |
|
|
|
public float maximumWalkingSpeed = 999; //The max walk velocity magnitude an agent will be rewarded for
|
|
|
|
Vector3 m_WalkDir; |
|
|
|
Quaternion m_WalkDirLookRot; |
|
|
|
|
|
|
|
[Space(10)] |
|
|
|
[Header("Orientation Cube")] |
|
|
|
[Space(10)] |
|
|
|
//This will be used as a stable observation platform for the ragdoll to use.
|
|
|
|
GameObject m_OrientationCube; |
|
|
|
public Transform directionIndicator; |
|
|
|
|
|
|
|
public float targetSpawnRadius; |
|
|
|
public Transform ground; |
|
|
|
public bool detectTargets; |
|
|
|
public bool targetIsStatic; |
|
|
|
public bool respawnTargetWhenTouched; |
|
|
|
|
|
|
|
Vector3 m_DirToTarget; |
|
|
|
[Header("Body Parts")] |
|
|
|
[Space(10)] |
|
|
|
public Transform hips; |
|
|
|
public Transform chest; |
|
|
|
public Transform spine; |
|
|
|
|
|
|
Rigidbody m_SpineRb; |
|
|
|
|
|
|
|
EnvironmentParameters m_ResetParams; |
|
|
|
|
|
|
|
|
|
|
|
//Spawn an orientation cube
|
|
|
|
Vector3 oCubePos = hips.position; |
|
|
|
oCubePos.y = -.45f; |
|
|
|
m_OrientationCube = Instantiate(Resources.Load<GameObject>("OrientationCube"), oCubePos, Quaternion.identity); |
|
|
|
m_OrientationCube.transform.SetParent(transform); |
|
|
|
|
|
|
|
UpdateOrientationCube(); |
|
|
|
|
|
|
|
m_JdController = GetComponent<JointDriveController>(); |
|
|
|
m_JdController.SetupBodyPart(hips); |
|
|
|
m_JdController.SetupBodyPart(chest); |
|
|
|
|
|
|
/// </summary>
|
|
|
|
public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor) |
|
|
|
{ |
|
|
|
var rb = bp.rb; |
|
|
|
//GROUND CHECK
|
|
|
|
sensor.AddObservation(rb.velocity); |
|
|
|
sensor.AddObservation(rb.angularVelocity); |
|
|
|
var localPosRelToHips = hips.InverseTransformPoint(rb.position); |
|
|
|
sensor.AddObservation(localPosRelToHips); |
|
|
|
|
|
|
|
//Get velocities in the context of our orientation cube's space
|
|
|
|
//Note: You can get these velocities in world space as well but it may not train as well.
|
|
|
|
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.velocity)); |
|
|
|
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity)); |
|
|
|
|
|
|
|
//Get position relative to hips in the context of our orientation cube's space
|
|
|
|
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(bp.rb.position - hips.position)); //best
|
|
|
|
if (bp.rb.transform != hips && bp.rb.transform != handL && bp.rb.transform != handR && |
|
|
|
bp.rb.transform != footL && bp.rb.transform != footR && bp.rb.transform != head) |
|
|
|
if (bp.rb.transform != hips && bp.rb.transform != handL && bp.rb.transform != handR) |
|
|
|
sensor.AddObservation(bp.currentXNormalizedRot); |
|
|
|
sensor.AddObservation(bp.currentYNormalizedRot); |
|
|
|
sensor.AddObservation(bp.currentZNormalizedRot); |
|
|
|
sensor.AddObservation(bp.rb.transform.localRotation); |
|
|
|
|
|
|
|
|
|
|
|
m_JdController.GetCurrentJointForces(); |
|
|
|
|
|
|
|
sensor.AddObservation(m_DirToTarget.normalized); |
|
|
|
sensor.AddObservation(m_JdController.bodyPartsDict[hips].rb.position); |
|
|
|
sensor.AddObservation(hips.forward); |
|
|
|
sensor.AddObservation(hips.up); |
|
|
|
sensor.AddObservation(Quaternion.FromToRotation(hips.forward, m_OrientationCube.transform.forward)); |
|
|
|
sensor.AddObservation(Quaternion.FromToRotation(head.forward, m_OrientationCube.transform.forward)); |
|
|
|
|
|
|
|
foreach (var bodyPart in m_JdController.bodyPartsDict.Values) |
|
|
|
{ |
|
|
|
|
|
|
bpDict[footR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]); |
|
|
|
bpDict[footL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], vectorAction[++i]); |
|
|
|
|
|
|
|
|
|
|
|
bpDict[armL].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); |
|
|
|
bpDict[armR].SetJointTargetRotation(vectorAction[++i], vectorAction[++i], 0); |
|
|
|
bpDict[forearmL].SetJointTargetRotation(vectorAction[++i], 0, 0); |
|
|
|
|
|
|
bpDict[forearmR].SetJointStrength(vectorAction[++i]); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
/// Re-aims the orientation cube and the direction indicator so they sit on the hips
/// and face the target along the horizontal plane.
/// </summary>
/// <remarks>
/// Updates <c>m_WalkDir</c> (flattened direction from the cube to the target) and
/// <c>m_WalkDirLookRot</c> (the matching look rotation), then applies them to the
/// orientation cube and to <c>directionIndicator</c>.
/// </remarks>
void UpdateOrientationCube()
{
    //FACING DIR
    m_WalkDir = target.position - m_OrientationCube.transform.position;
    m_WalkDir.y = 0; //flatten dir on the y

    // Once flattened, the direction is zero whenever the target is directly above or
    // below the cube. Quaternion.LookRotation cannot build a rotation from a zero
    // vector, so in that case keep the cube's current facing instead.
    if (m_WalkDir != Vector3.zero)
    {
        m_WalkDirLookRot = Quaternion.LookRotation(m_WalkDir); //get our look rot to the target
    }
    else
    {
        m_WalkDirLookRot = m_OrientationCube.transform.rotation;
    }

    //UPDATE ORIENTATION CUBE POS & ROT
    m_OrientationCube.transform.position = hips.position;
    m_OrientationCube.transform.rotation = m_WalkDirLookRot;

    // The indicator follows the hips on x/z only; it keeps its own height.
    directionIndicator.position = new Vector3(hips.position.x, directionIndicator.position.y, hips.position.z);
    directionIndicator.rotation = m_WalkDirLookRot;
}
|
|
|
|
|
|
|
if (detectTargets) |
|
|
|
{ |
|
|
|
foreach (var bodyPart in m_JdController.bodyPartsDict.Values) |
|
|
|
{ |
|
|
|
if (bodyPart.targetContact && bodyPart.targetContact.touchingTarget) |
|
|
|
{ |
|
|
|
TouchedTarget(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
UpdateOrientationCube(); |
|
|
|
|
|
|
|
// d. Discourage head movement.
|
|
|
|
m_DirToTarget = target.position - m_JdController.bodyPartsDict[hips].rb.position; |
|
|
|
+0.03f * Vector3.Dot(m_DirToTarget.normalized, m_JdController.bodyPartsDict[hips].rb.velocity) |
|
|
|
+ 0.01f * Vector3.Dot(m_DirToTarget.normalized, hips.forward) |
|
|
|
+ 0.02f * (head.position.y - hips.position.y) |
|
|
|
- 0.01f * Vector3.Distance(m_JdController.bodyPartsDict[head].rb.velocity, |
|
|
|
m_JdController.bodyPartsDict[hips].rb.velocity) |
|
|
|
+0.02f * Vector3.Dot(m_OrientationCube.transform.forward, |
|
|
|
Vector3.ClampMagnitude(m_JdController.bodyPartsDict[hips].rb.velocity, maximumWalkingSpeed)) |
|
|
|
+ 0.01f * Vector3.Dot(m_OrientationCube.transform.forward, head.forward) |
|
|
|
+ 0.005f * (head.position.y - footL.position.y) |
|
|
|
+ 0.005f * (head.position.y - footR.position.y) |
|
|
|
/// <summary>
/// Agent touched the target.
|
|
|
|
/// </summary>
|
|
|
|
public void TouchedTarget()
{
    // Reaching the target always earns the full reward.
    AddReward(1f);

    // The target only moves when respawning on touch is enabled.
    if (!respawnTargetWhenTouched)
    {
        return;
    }

    GetRandomTargetPos();
}
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Moves target to a random position within specified radius.
|
|
|
|
/// </summary>
|
|
|
|
public void GetRandomTargetPos()
{
    // Random point inside a sphere of radius targetSpawnRadius; only x/z are kept.
    Vector3 spawnOffset = Random.insideUnitSphere * targetSpawnRadius;

    // The height is pinned to 5 units, and the offset is relative to the ground's position.
    Vector3 raisedOffset = new Vector3(spawnOffset.x, 5, spawnOffset.z);
    target.position = raisedOffset + ground.position;
}
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
if (m_DirToTarget != Vector3.zero) |
|
|
|
foreach (var bodyPart in m_JdController.bodyPartsDict.Values) |
|
|
|
transform.rotation = Quaternion.LookRotation(m_DirToTarget); |
|
|
|
bodyPart.Reset(bodyPart); |
|
|
|
|
|
|
|
//Random start rotation
|
|
|
|
transform.rotation = Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0); |
|
|
|
|
|
|
|
UpdateOrientationCube(); |
|
|
|
foreach (var bodyPart in m_JdController.bodyPartsDict.Values) |
|
|
|
if (detectTargets && !targetIsStatic) |
|
|
|
bodyPart.Reset(bodyPart); |
|
|
|
GetRandomTargetPos(); |
|
|
|
|
|
|
|
SetResetParameters(); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Applies the reset-time parameters; currently this only calls <c>SetTorsoMass</c>.
/// </summary>
public void SetResetParameters() => SetTorsoMass();
|
|
|
|
|
|
|
/// <summary>
/// Unity editor gizmo callback: while the game is playing, draws the orientation cube
/// as a green wireframe together with a ray along its forward axis.
/// </summary>
private void OnDrawGizmosSelected()
{
    // Nothing to draw outside play mode. The cube is only created during agent setup,
    // so also skip drawing when it does not exist yet or has been destroyed.
    if (!Application.isPlaying || m_OrientationCube == null)
    {
        return;
    }

    Transform cube = m_OrientationCube.transform;

    Gizmos.color = Color.green;
    // Draw in the cube's local space so the wireframe and ray follow its pose.
    Gizmos.matrix = cube.localToWorldMatrix;
    Gizmos.DrawWireCube(Vector3.zero, cube.localScale);
    Gizmos.DrawRay(Vector3.zero, Vector3.forward);
}
|
|
|
} |