浏览代码

Fix worm environment. (#4337)

* Fix worm environment.

* Remove dead code from WormAgent.cs

* remove target logic. change obsv loop to list instead of dict

Co-authored-by: HH <brandonh@unity3d.com>
/release_6_branch
GitHub 4 年前
当前提交
957010bf
共有 2 个文件被更改,包括 64 次插入和 58 次删除
  1. 46
      Project/Assets/ML-Agents/Examples/Worm/Prefabs/PlatformWormDynamicTarget.prefab
  2. 76
      Project/Assets/ML-Agents/Examples/Worm/Scripts/WormAgent.cs

46
Project/Assets/ML-Agents/Examples/Worm/Prefabs/PlatformWormDynamicTarget.prefab


m_Interpolate: 0
m_Constraints: 0
m_CollisionDetection: 0
--- !u!114 &8042564747579887
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 7516457449653310668}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 3c8f113a8b8d94967b1b1782c549be81, type: 3}
m_Name:
m_EditorClassIdentifier:
tagToDetect: agent
spawnRadius: 40
respawnIfTouched: 1
respawnIfFallsOffPlatform: 1
fallDistance: 5
onTriggerEnterEvent:
m_PersistentCalls:
m_Calls: []
onTriggerStayEvent:
m_PersistentCalls:
m_Calls: []
onTriggerExitEvent:
m_PersistentCalls:
m_Calls: []
onCollisionEnterEvent:
m_PersistentCalls:
m_Calls: []
onCollisionStayEvent:
m_PersistentCalls:
m_Calls: []
onCollisionExitEvent:
m_PersistentCalls:
m_Calls: []
--- !u!1001 &906401165941233076
PrefabInstance:
m_ObjectHideFlags: 0

propertyPath: ground
value:
objectReference: {fileID: 7519759559437056804}
- target: {fileID: 6060305997946326746, guid: 3ebcde4cf2d5c4c029e2a5ce3d853aba,
type: 3}
propertyPath: m_TagString
value: agent
objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: 3ebcde4cf2d5c4c029e2a5ce3d853aba, type: 3}
--- !u!1001 &7202236613889278392

--- !u!4 &7519759559437056804 stripped
Transform:
m_CorrespondingSourceObject: {fileID: 840186797462469276, guid: d6fc96a99a9754f07b48abf1e0d55a5c,
type: 3}
m_PrefabInstance: {fileID: 7202236613889278392}
m_PrefabAsset: {fileID: 0}
--- !u!1 &7516457449653310668 stripped
GameObject:
m_CorrespondingSourceObject: {fileID: 845742365997159796, guid: d6fc96a99a9754f07b48abf1e0d55a5c,
type: 3}
m_PrefabInstance: {fileID: 7202236613889278392}
m_PrefabAsset: {fileID: 0}

76
Project/Assets/ML-Agents/Examples/Worm/Scripts/WormAgent.cs


[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class WormAgent : Agent
{
[Header("Target To Walk Towards")]
[Space(10)]
[Header("Target To Walk Towards")] [Space(10)]
public bool detectTargets;
public bool targetIsStatic;
public bool respawnTargetWhenTouched;
public float targetSpawnRadius;
[Header("Body Parts")] [Space(10)]
public Transform bodySegment0;
[Header("Body Parts")] [Space(10)] public Transform bodySegment0;
[Header("Joint Settings")] [Space(10)]
JointDriveController m_JdController;
[Header("Joint Settings")] [Space(10)] JointDriveController m_JdController;
[Header("Reward Functions To Use")]
[Space(10)]
[Header("Reward Functions To Use")] [Space(10)]
public bool rewardMovingTowardsTarget; // Agent should move towards target
public bool rewardFacingTarget; // Agent should face the target

m_JdController.SetupBodyPart(bodySegment1);
m_JdController.SetupBodyPart(bodySegment2);
m_JdController.SetupBodyPart(bodySegment3);
//We only want the head to detect the target
//So we need to remove TargetContact from everything else
//This is a temp fix till we can redesign
DestroyImmediate(bodySegment1.GetComponent<TargetContact>());
DestroyImmediate(bodySegment2.GetComponent<TargetContact>());
DestroyImmediate(bodySegment3.GetComponent<TargetContact>());
}

{
return(Quaternion.FromToRotation(joint.axis, joint.connectedBody.transform.rotation.eulerAngles));
return (Quaternion.FromToRotation(joint.axis, joint.connectedBody.transform.rotation.eulerAngles));
}
/// <summary>

var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.velocity);
sensor.AddObservation(velocityRelativeToLookRotationToTarget);
var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
var angularVelocityRelativeToLookRotationToTarget =
m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
sensor.AddObservation(angularVelocityRelativeToLookRotationToTarget);
if (bp.rb.transform != bodySegment0)

float maxDist = 10;
if (Physics.Raycast(bodySegment0.position, Vector3.down, out hit, maxDist))
{
sensor.AddObservation(hit.distance/maxDist);
sensor.AddObservation(hit.distance / maxDist);
foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
foreach (var bodyPart in m_JdController.bodyPartsList)
Quaternion headRotationDeltaFromMatrixRot = Quaternion.Inverse(m_TargetDirMatrix.rotation) * bodySegment0.rotation;
Quaternion headRotationDeltaFromMatrixRot =
Quaternion.Inverse(m_TargetDirMatrix.rotation) * bodySegment0.rotation;
sensor.AddObservation(headRotationDeltaFromMatrixRot);
}

/// <summary>
/// Adds a reward of 1 to the agent and, when <c>respawnTargetWhenTouched</c>
/// is enabled, relocates the target via <c>GetRandomTargetPos</c>.
/// </summary>
public void TouchedTarget()
{
    AddReward(1f);

    // Nothing more to do when the target is configured to stay in place.
    if (!respawnTargetWhenTouched)
    {
        return;
    }

    GetRandomTargetPos();
}
/// <summary>
/// Moves the target to a random position around the ground transform.
/// The offset is a random point inside the unit sphere scaled by
/// <c>targetSpawnRadius</c>, with its y component overridden to 5.
/// </summary>
public void GetRandomTargetPos()
{
    Vector3 spawnOffset = Random.insideUnitSphere * targetSpawnRadius;

    // The sampled height is discarded; the offset always uses y = 5
    // relative to the ground position.
    spawnOffset.y = 5;

    target.position = ground.position + spawnOffset;
}
public override void OnActionReceived(float[] vectorAction)

bpDict[bodySegment2].SetJointStrength(vectorAction[++i]);
bpDict[bodySegment3].SetJointStrength(vectorAction[++i]);
if (bodySegment0.position.y < ground.position.y -2)
// Detect if worm fell off/through platform
if (bodySegment0.position.y < ground.position.y - 2)
{
EndEpisode();
}

{
if (detectTargets)
{
foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
{
if (bodyPart.targetContact && bodyPart.targetContact.touchingTarget)
{
TouchedTarget();
}
}
}
// Set reward for this step according to mixture of the following elements.
if (rewardMovingTowardsTarget)
{

/// </summary>
void RewardFunctionMovingTowards()
{
m_MovingTowardsDot = Vector3.Dot(m_JdController.bodyPartsDict[bodySegment0].rb.velocity, m_DirToTarget.normalized);
m_MovingTowardsDot =
Vector3.Dot(m_JdController.bodyPartsDict[bodySegment0].rb.velocity, m_DirToTarget.normalized);
AddReward(0.01f * m_MovingTowardsDot);
}

}
/// <summary>
/// Existential penalty for time-contrained tasks.
/// Existential penalty for time-constrained tasks.
/// </summary>
void RewardFunctionTimePenalty()
{

{
bodyPart.Reset(bodyPart);
}
transform.Rotate(Vector3.up, Random.Range(0.0f, 360.0f));
if (!targetIsStatic)
{
GetRandomTargetPos();
}
transform.Rotate(Vector3.up, Random.Range(0.0f, 360.0f));
}
}
正在加载...
取消
保存