|
|
|
|
|
|
using UnityEngine; |
|
|
|
using Unity.MLAgents; |
|
|
|
using Unity.MLAgents.Sensors; |
|
|
|
using System.Collections.Generic; |
|
|
|
|
|
|
|
public class NewReacherAgent : Agent |
|
|
|
{ |
|
|
|
|
|
|
public GameObject goal; |
|
|
|
public GameObject ballType; |
|
|
|
public float ballRange; |
|
|
|
public int ballNumber; |
|
|
|
float m_GoalDegree; |
|
|
|
Rigidbody m_RbA; |
|
|
|
Rigidbody m_RbB; |
|
|
|
|
|
|
float m_Deviation; |
|
|
|
// Frequency of the cosine deviation of the goal along the vertical dimension
|
|
|
|
float m_DeviationFreq; |
|
|
|
|
|
|
|
List<GameObject> m_balls; |
|
|
|
List<float> m_BallDegrees; |
|
|
|
float m_BallSpeed; |
|
|
|
EnvironmentParameters m_ResetParams; |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
m_ResetParams = Academy.Instance.EnvironmentParameters; |
|
|
|
|
|
|
|
SetResetParameters(); |
|
|
|
|
|
|
|
m_balls = new List<GameObject>(); |
|
|
|
m_BallDegrees = new List<float>(); |
|
|
|
CreateBalls(); |
|
|
|
UpdateBallsPosition(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
sensor.AddObservation(m_RbA.velocity); |
|
|
|
|
|
|
|
sensor.AddObservation(m_GoalSpeed); |
|
|
|
|
|
|
|
// irrelevant observations
|
|
|
|
for ( int i = 0; i < ballNumber; i++) |
|
|
|
{ |
|
|
|
sensor.AddObservation(m_balls[i].transform.localPosition); |
|
|
|
} |
|
|
|
sensor.AddObservation(m_BallSpeed); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
{ |
|
|
|
m_GoalDegree += m_GoalSpeed; |
|
|
|
for ( int i = 0; i < ballNumber; i++) |
|
|
|
{ |
|
|
|
m_BallDegrees[i] += m_BallSpeed; |
|
|
|
} |
|
|
|
|
|
|
|
UpdateBallsPosition(); |
|
|
|
|
|
|
|
var torqueX = Mathf.Clamp(vectorAction[0], -1f, 1f) * 150f; |
|
|
|
var torqueZ = Mathf.Clamp(vectorAction[1], -1f, 1f) * 150f; |
|
|
|
|
|
|
m_RbB.AddTorque(new Vector3(torqueX, 0f, torqueZ)); |
|
|
|
|
|
|
|
AddReward( - 0.005f * (vectorAction[0] * vectorAction[0] |
|
|
|
+ vectorAction[1] * vectorAction[1] |
|
|
|
+ vectorAction[2] * vectorAction[2] |
|
|
|
+ vectorAction[3] * vectorAction[3] |
|
|
|
)); |
|
|
|
+ vectorAction[1] * vectorAction[1] |
|
|
|
+ vectorAction[2] * vectorAction[2] |
|
|
|
+ vectorAction[3] * vectorAction[3] |
|
|
|
)); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
{ |
|
|
|
if ((goal.transform.position - hand.transform.position).magnitude > 3.5f) |
|
|
|
{ |
|
|
|
AddReward(-0.001f); |
|
|
|
} |
|
|
|
// AddReward( - 0.001f * (goal.transform.position - hand.transform.position).magnitude);
|
|
|
|
var radians = m_GoalDegree * Mathf.PI / 180f; |
|
|
|
var goalX = 8f * Mathf.Cos(radians); |
|
|
|
var goalY = 8f * Mathf.Sin(radians); |
|
|
|
|
|
|
|
|
|
|
void UpdateBallsPosition() |
|
|
|
{ |
|
|
|
for (int i = 0; i < ballNumber; i++) |
|
|
|
{ |
|
|
|
var radians = m_BallDegrees[i] * Mathf.PI / 180f; |
|
|
|
var ballX = 8f * Mathf.Cos(radians); |
|
|
|
var ballY = 8f * Mathf.Sin(radians); |
|
|
|
var ballZ = m_Deviation * Mathf.Cos(m_DeviationFreq * radians); |
|
|
|
m_balls[i].transform.position = new Vector3(ballY, ballZ, ballX) + transform.position; |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
void CreateBalls() |
|
|
|
{ |
|
|
|
for (int i = 0; i < ballNumber; i++) |
|
|
|
{ |
|
|
|
GameObject b = Instantiate(ballType); |
|
|
|
m_balls.Add(b); |
|
|
|
} |
|
|
|
for (int i=0; i < ballNumber; i++) |
|
|
|
{ |
|
|
|
m_BallDegrees.Add(Random.Range(0, 360)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Resets the position and velocity of the agent and the goal.
|
|
|
|
/// </summary>
|
|
|
|
|
|
|
m_RbB.angularVelocity = Vector3.zero; |
|
|
|
|
|
|
|
m_GoalDegree = Random.Range(0, 360); |
|
|
|
for (int i=0; i < ballNumber; i++) |
|
|
|
{ |
|
|
|
m_BallDegrees[i] = Random.Range(0, 360); |
|
|
|
} |
|
|
|
|
|
|
|
UpdateBallsPosition(); |
|
|
|
|
|
|
|
SetResetParameters(); |
|
|
|
|
|
|
|
|
|
|
m_GoalSpeed = Random.Range(-1f, 1f) * m_ResetParams.GetWithDefault("goal_speed", 1); |
|
|
|
m_Deviation = m_ResetParams.GetWithDefault("deviation", 0); |
|
|
|
m_DeviationFreq = m_ResetParams.GetWithDefault("deviation_freq", 0); |
|
|
|
m_BallSpeed = Random.Range(-0.5f, 0.5f) * m_ResetParams.GetWithDefault("ball_speed", 1); |
|
|
|
} |
|
|
|
} |