浏览代码

better logging for NaN rewards (#4205)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
30a5274c
共有 2 个文件被更改,包括 54 次插入4 次删除
  1. 25
      Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
  2. 33
      Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs

25
Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs


using System;
using Random = UnityEngine.Random;
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class CrawlerAgent : Agent

{
//Add body rotation delta relative to orientation cube
sensor.AddObservation(Quaternion.FromToRotation(body.forward, orientationCube.transform.forward));
//Add pos of target relative to orientation cube
sensor.AddObservation(orientationCube.transform.InverseTransformPoint(target.transform.position));

{
var movingTowardsDot = Vector3.Dot(orientationCube.transform.forward,
Vector3.ClampMagnitude(m_JdController.bodyPartsDict[body].rb.velocity, maximumWalkingSpeed));
;
if (float.IsNaN(movingTowardsDot))
{
throw new ArgumentException(
"NaN in movingTowardsDot.\n" +
$" orientationCube.transform.forward: {orientationCube.transform.forward}\n"+
$" body.velocity: {m_JdController.bodyPartsDict[body].rb.velocity}\n"+
$" maximumWalkingSpeed: {maximumWalkingSpeed}"
);
}
AddReward(0.03f * movingTowardsDot);
}

void RewardFunctionFacingTarget()
{
AddReward(0.01f * Vector3.Dot(orientationCube.transform.forward, body.forward));
var facingReward = Vector3.Dot(orientationCube.transform.forward, body.forward);
if (float.IsNaN(facingReward))
{
throw new ArgumentException(
"NaN in movingTowardsDot.\n" +
$" orientationCube.transform.forward: {orientationCube.transform.forward}\n"+
$" body.forward: {body.forward}"
);
}
AddReward(0.01f * facingReward);
}
/// <summary>

33
Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs


using System;
using MLAgentsExamples;
using UnityEngine;
using Unity.MLAgents;

using Random = UnityEngine.Random;
public class WalkerAgent : Agent
{

// a. Velocity alignment with goal direction.
var moveTowardsTargetReward = Vector3.Dot(cubeForward,
Vector3.ClampMagnitude(m_JdController.bodyPartsDict[hips].rb.velocity, maximumWalkingSpeed));
if (float.IsNaN(moveTowardsTargetReward))
{
throw new ArgumentException(
"NaN in moveTowardsTargetReward.\n" +
$" cubeForward: {cubeForward}\n"+
$" hips.velocity: {m_JdController.bodyPartsDict[hips].rb.velocity}\n"+
$" maximumWalkingSpeed: {maximumWalkingSpeed}"
);
}
if (float.IsNaN(lookAtTargetReward))
{
throw new ArgumentException(
"NaN in lookAtTargetReward.\n" +
$" cubeForward: {cubeForward}\n"+
$" head.forward: {head.forward}"
);
}
var headHeightOverFeetReward =
var headHeightOverFeetReward =
if (float.IsNaN(headHeightOverFeetReward))
{
throw new ArgumentException(
"NaN in headHeightOverFeetReward.\n" +
$" head.position: {head.position}\n"+
$" footL.position: {footL.position}\n"+
$" footR.position: {footR.position}"
);
}
AddReward(
+ 0.02f * moveTowardsTargetReward
+ 0.02f * lookAtTargetReward

正在加载...
取消
保存