[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class WormAgent : Agent
{
[Header("Target To Walk Towards")]
[Space(10)]
[Header("Target To Walk Towards")] [ Space ( 1 0 ) ]
public bool detectTargets ;
public bool targetIsStatic ;
public bool respawnTargetWhenTouched ;
public float targetSpawnRadius ;
[Header("Body Parts")] [ Space ( 1 0 ) ]
public Transform bodySegment0 ;
[Header("Body Parts")] [ Space ( 1 0 ) ] public Transform bodySegment0 ;
[Header("Joint Settings")] [ Space ( 1 0 ) ]
JointDriveController m_JdController ;
[Header("Joint Settings")] [ Space ( 1 0 ) ] JointDriveController m_JdController ;
[Header("Reward Functions To Use")]
[Space(10)]
[Header("Reward Functions To Use")] [ Space ( 1 0 ) ]
public bool rewardMovingTowardsTarget ; // Agent should move towards target
public bool rewardFacingTarget ; // Agent should face the target
m_JdController . SetupBodyPart ( bodySegment1 ) ;
m_JdController . SetupBodyPart ( bodySegment2 ) ;
m_JdController . SetupBodyPart ( bodySegment3 ) ;
//We only want the head to detect the target
//So we need to remove TargetContact from everything else
//This is a temp fix till we can redesign
DestroyImmediate ( bodySegment1 . GetComponent < TargetContact > ( ) ) ;
DestroyImmediate ( bodySegment2 . GetComponent < TargetContact > ( ) ) ;
DestroyImmediate ( bodySegment3 . GetComponent < TargetContact > ( ) ) ;
}
{
return ( Quaternion . FromToRotation ( joint . axis , joint . connectedBody . transform . rotation . eulerAngles ) ) ;
return ( Quaternion . FromToRotation ( joint . axis , joint . connectedBody . transform . rotation . eulerAngles ) ) ;
}
/// <summary>
var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix . inverse . MultiplyVector ( rb . velocity ) ;
sensor . AddObservation ( velocityRelativeToLookRotationToTarget ) ;
var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix . inverse . MultiplyVector ( rb . angularVelocity ) ;
var angularVelocityRelativeToLookRotationToTarget =
m_TargetDirMatrix . inverse . MultiplyVector ( rb . angularVelocity ) ;
sensor . AddObservation ( angularVelocityRelativeToLookRotationToTarget ) ;
if ( bp . rb . transform ! = bodySegment0 )
float maxDist = 1 0 ;
if ( Physics . Raycast ( bodySegment0 . position , Vector3 . down , out hit , maxDist ) )
{
sensor . AddObservation ( hit . distance / maxDist ) ;
sensor . AddObservation ( hit . distance / maxDist ) ;
foreach ( var bodyPart in m_JdController . bodyPartsDict . Values )
foreach ( var bodyPart in m_JdController . bodyPartsList )
Quaternion headRotationDeltaFromMatrixRot = Quaternion . Inverse ( m_TargetDirMatrix . rotation ) * bodySegment0 . rotation ;
Quaternion headRotationDeltaFromMatrixRot =
Quaternion . Inverse ( m_TargetDirMatrix . rotation ) * bodySegment0 . rotation ;
sensor . AddObservation ( headRotationDeltaFromMatrixRot ) ;
}
/// <summary>
/// Called when the agent touches the target: adds a reward of 1 and, if
/// <c>respawnTargetWhenTouched</c> is enabled, moves the target by calling
/// <c>GetRandomTargetPos</c>.
/// </summary>
public void TouchedTarget()
{
    AddReward(1f);

    // Nothing more to do when the target is meant to stay where it is.
    if (!respawnTargetWhenTouched)
    {
        return;
    }

    GetRandomTargetPos();
}
/// <summary>
/// Relocates the target to a random spot relative to the ground.
/// A random point inside the unit sphere is scaled by <c>targetSpawnRadius</c>,
/// its height component is then overridden with a fixed value of 5, and the
/// result is applied as an offset from <c>ground.position</c>.
/// </summary>
public void GetRandomTargetPos()
{
    Vector3 spawnOffset = Random.insideUnitSphere * targetSpawnRadius;

    // Discard the sampled height: the target always sits 5 units above the ground origin.
    spawnOffset.y = 5;

    target.position = ground.position + spawnOffset;
}
public override void OnActionReceived ( float [ ] vectorAction )
bpDict [ bodySegment2 ] . SetJointStrength ( vectorAction [ + + i ] ) ;
bpDict [ bodySegment3 ] . SetJointStrength ( vectorAction [ + + i ] ) ;
if ( bodySegment0 . position . y < ground . position . y - 2 )
// Detect if worm fell off/through platform
if ( bodySegment0 . position . y < ground . position . y - 2 )
{
EndEpisode ( ) ;
}
{
if ( detectTargets )
{
foreach ( var bodyPart in m_JdController . bodyPartsDict . Values )
{
if ( bodyPart . targetContact & & bodyPart . targetContact . touchingTarget )
{
TouchedTarget ( ) ;
}
}
}
// Set reward for this step according to mixture of the following elements.
if ( rewardMovingTowardsTarget )
{
/// </summary>
void RewardFunctionMovingTowards()
{
    // Reward is proportional to how much of bodySegment0's velocity points along
    // the normalized m_DirToTarget vector: positive when moving along it,
    // negative when moving against it.
    var segmentVelocity = m_JdController.bodyPartsDict[bodySegment0].rb.velocity;

    // Computed once; the result is also stored in m_MovingTowardsDot
    // (presumably read elsewhere in the class — not visible here).
    m_MovingTowardsDot = Vector3.Dot(segmentVelocity, m_DirToTarget.normalized);

    // Small scale factor keeps this per-step shaping reward well below the
    // reward of 1 granted by TouchedTarget.
    AddReward(0.01f * m_MovingTowardsDot);
}
}
/// <summary>
/// Existential penalty for time-constrained tasks.
/// </summary>
void RewardFunctionTimePenalty ( )
{
{
bodyPart . Reset ( bodyPart ) ;
}
transform . Rotate ( Vector3 . up , Random . Range ( 0.0f , 3 6 0.0f ) ) ;
if ( ! targetIsStatic )
{
GetRandomTargetPos ( ) ;
}
transform . Rotate ( Vector3 . up , Random . Range ( 0.0f , 3 6 0.0f ) ) ;
}
}