浏览代码

reacher new reward: action penalty and constant not-reaching-goal penalty

/develop/bisim-sac-transfer
yanchaosun 5 年前
当前提交
883361ee
共有 7 个文件被更改,包括 79 次插入13 次删除
  1. 2
      Project/Assets/ML-Agents/Examples/Reacher/Prefabs/Agent.prefab
  2. 72
      Project/Assets/ML-Agents/Examples/Reacher/Scripts/NewReacherAgent.cs
  3. 2
      Project/Assets/ML-Agents/Examples/Reacher/Scripts/NewReacherGoal.cs
  4. 8
      Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
  5. 2
      Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherGoal.cs
  6. 2
      config/sac_transfer/Reacher.yaml
  7. 4
      config/sac_transfer/ReacherTransfer.yaml

2
Project/Assets/ML-Agents/Examples/Reacher/Prefabs/Agent.prefab


VectorActionSize: 04000000
VectorActionDescriptions: []
VectorActionSpaceType: 1
m_Model: {fileID: 11400000, guid: e12acd64209f9468c899b9708b2702c3, type: 3}
m_Model: {fileID: 11400000, guid: d7bdb6a78154f4cf99437d67e4a569a8, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0
m_BehaviorName: Reacher

72
Project/Assets/ML-Agents/Examples/Reacher/Scripts/NewReacherAgent.cs


using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using System.Collections.Generic;
public class NewReacherAgent : Agent
{

public GameObject goal;
public GameObject ballType;
public float ballRange;
public int ballNumber;
float m_GoalDegree;
Rigidbody m_RbA;
Rigidbody m_RbB;

float m_Deviation;
// Frequency of the cosine deviation of the goal along the vertical dimension
float m_DeviationFreq;
List<GameObject> m_balls;
List<float> m_BallDegrees;
float m_BallSpeed;
EnvironmentParameters m_ResetParams;
/// <summary>

m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
m_balls = new List<GameObject>();
m_BallDegrees = new List<float>();
CreateBalls();
UpdateBallsPosition();
}
/// <summary>

sensor.AddObservation(m_RbA.velocity);
sensor.AddObservation(m_GoalSpeed);
// irrelevant observations
for ( int i = 0; i < ballNumber; i++)
{
sensor.AddObservation(m_balls[i].transform.localPosition);
}
sensor.AddObservation(m_BallSpeed);
}
/// <summary>

{
m_GoalDegree += m_GoalSpeed;
for ( int i = 0; i < ballNumber; i++)
{
m_BallDegrees[i] += m_BallSpeed;
}
UpdateBallsPosition();
var torqueX = Mathf.Clamp(vectorAction[0], -1f, 1f) * 150f;
var torqueZ = Mathf.Clamp(vectorAction[1], -1f, 1f) * 150f;

m_RbB.AddTorque(new Vector3(torqueX, 0f, torqueZ));
AddReward( - 0.005f * (vectorAction[0] * vectorAction[0]
+ vectorAction[1] * vectorAction[1]
+ vectorAction[2] * vectorAction[2]
+ vectorAction[3] * vectorAction[3]
));
+ vectorAction[1] * vectorAction[1]
+ vectorAction[2] * vectorAction[2]
+ vectorAction[3] * vectorAction[3]
));
}
/// <summary>

{
if ((goal.transform.position - hand.transform.position).magnitude > 3.5f)
{
AddReward(-0.001f);
}
// AddReward( - 0.001f * (goal.transform.position - hand.transform.position).magnitude);
var radians = m_GoalDegree * Mathf.PI / 180f;
var goalX = 8f * Mathf.Cos(radians);
var goalY = 8f * Mathf.Sin(radians);

void UpdateBallsPosition()
{
for (int i = 0; i < ballNumber; i++)
{
var radians = m_BallDegrees[i] * Mathf.PI / 180f;
var ballX = 8f * Mathf.Cos(radians);
var ballY = 8f * Mathf.Sin(radians);
var ballZ = m_Deviation * Mathf.Cos(m_DeviationFreq * radians);
m_balls[i].transform.position = new Vector3(ballY, ballZ, ballX) + transform.position;
}
}
void CreateBalls()
{
for (int i = 0; i < ballNumber; i++)
{
GameObject b = Instantiate(ballType);
m_balls.Add(b);
}
for (int i=0; i < ballNumber; i++)
{
m_BallDegrees.Add(Random.Range(0, 360));
}
}
/// <summary>
/// Resets the position and velocity of the agent and the goal.
/// </summary>

m_RbB.angularVelocity = Vector3.zero;
m_GoalDegree = Random.Range(0, 360);
for (int i=0; i < ballNumber; i++)
{
m_BallDegrees[i] = Random.Range(0, 360);
}
UpdateBallsPosition();
SetResetParameters();

m_GoalSpeed = Random.Range(-1f, 1f) * m_ResetParams.GetWithDefault("goal_speed", 1);
m_Deviation = m_ResetParams.GetWithDefault("deviation", 0);
m_DeviationFreq = m_ResetParams.GetWithDefault("deviation_freq", 0);
m_BallSpeed = Random.Range(-0.5f, 0.5f) * m_ResetParams.GetWithDefault("ball_speed", 1);
}
}

2
Project/Assets/ML-Agents/Examples/Reacher/Scripts/NewReacherGoal.cs


{
if (other.gameObject == hand)
{
agent.GetComponent<NewReacherAgent>().AddReward(0.05f);
agent.GetComponent<NewReacherAgent>().AddReward(0.01f);
}
}
}

8
Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs


/// </summary>
void UpdateGoalPosition()
{
AddReward( - 0.001f * (goal.transform.position - hand.transform.position).magnitude);
// Debug.Log( - 0.001f * (goal.transform.position - hand.transform.position).magnitude);
if ((goal.transform.position - hand.transform.position).magnitude > 3.5f)
{
AddReward(-0.001f);
}
// AddReward( - 0.001f * (goal.transform.position - hand.transform.position).magnitude);
// Debug.Log((goal.transform.position - hand.transform.position).magnitude);
var radians = m_GoalDegree * Mathf.PI / 180f;
var goalX = 8f * Mathf.Cos(radians);
var goalY = 8f * Mathf.Sin(radians);

2
Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherGoal.cs


{
if (other.gameObject == hand)
{
agent.GetComponent<ReacherAgent>().AddReward(0.05f);
agent.GetComponent<ReacherAgent>().AddReward(0.01f);
}
}
}

2
config/sac_transfer/Reacher.yaml


reward_signal_steps_per_update: 20.0
encoder_layers: 1
policy_layers: 2
forward_layers: 1
forward_layers: 0
value_layers: 2
action_layers: 1
feature_size: 64

4
config/sac_transfer/ReacherTransfer.yaml


Reacher:
trainer_type: sac_transfer
hyperparameters:
learning_rate: 0.0003
learning_rate: 0.0006
learning_rate_schedule: constant
model_schedule: constant
batch_size: 256

reward_signal_steps_per_update: 20.0
encoder_layers: 1
policy_layers: 2
forward_layers: 1
forward_layers: 0
value_layers: 2
action_layers: 1
feature_size: 64

正在加载...
取消
保存