Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过 35 个字符。
 
 
 
 
 

149 行
4.9 KiB

using UnityEngine;
using UnityEngine.UI;
using MLAgents;
using MLAgents.Sensors;
using MLAgents.SideChannels;
/// <summary>
/// Agent controlling a single racket in the Tennis example.
/// When <see cref="invertX"/> is set, X-axis observations and actions are mirrored
/// (multiplied by -1) and the score is written to the "ScoreB" board instead of "ScoreA".
/// </summary>
public class TennisAgent : Agent
{
    [Header("Specific to Tennis")]
    public GameObject ball;
    public bool invertX;
    public int score;
    public GameObject myArea;
    public float angle;
    public float scale;

    [HideInInspector]
    public float timePenalty = 0;

    Text m_TextComponent;
    Rigidbody m_AgentRb;
    Rigidbody m_BallRb;
    TennisArea m_Area;
    // +1 for the regular agent, -1 for the mirrored (invertX) agent.
    float m_InvertMult;
    FloatPropertiesChannel m_ResetParams;
    // Reward scale per racket/ball collision, read from the "ball_touch" reset parameter.
    float m_BallTouch;

    // Looks for the scoreboard based on the name of the gameObjects.
    // Do not modify the names of the Score GameObjects
    const string k_CanvasName = "Canvas";
    const string k_ScoreBoardAName = "ScoreA";
    const string k_ScoreBoardBName = "ScoreB";

    // Tuning constants; values are identical to the literals previously inlined below.
    const float k_HorizontalSpeed = 30f;
    const float k_JumpSpeed = 20f;
    const float k_RacketTiltRange = 55f;
    const float k_TimePenaltyPerStep = 1f / 3000f;
    const float k_BallTouchReward = .2f;

    /// <summary>
    /// Caches component references, locates this agent's scoreboard text and
    /// applies the initial reset parameters.
    /// </summary>
    public override void Initialize()
    {
        m_AgentRb = GetComponent<Rigidbody>();
        m_BallRb = ball.GetComponent<Rigidbody>();
        m_Area = myArea.GetComponent<TennisArea>();
        m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();

        // SetRacket() multiplies by m_InvertMult. Without this assignment the field would
        // still be 0 until the first OnEpisodeBegin(), forcing the initial racket tilt to 0.
        m_InvertMult = GetInvertMult();

        var canvas = GameObject.Find(k_CanvasName);
        var scoreBoardName = invertX ? k_ScoreBoardBName : k_ScoreBoardAName;
        var scoreBoard = canvas.transform.Find(scoreBoardName).gameObject;
        m_TextComponent = scoreBoard.GetComponent<Text>();

        SetResetParameters();
    }

    /// <summary>
    /// Adds 9 float observations: agent position (x, y) and velocity (x, y), ball
    /// position (x, y) and velocity (x, y), and the racket rotation term. Positions are
    /// relative to <see cref="myArea"/>. X components are multiplied by m_InvertMult so
    /// the mirrored agent sees a mirrored view.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(m_InvertMult * (transform.position.x - myArea.transform.position.x));
        sensor.AddObservation(transform.position.y - myArea.transform.position.y);
        sensor.AddObservation(m_InvertMult * m_AgentRb.velocity.x);
        sensor.AddObservation(m_AgentRb.velocity.y);
        sensor.AddObservation(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x));
        sensor.AddObservation(ball.transform.position.y - myArea.transform.position.y);
        sensor.AddObservation(m_InvertMult * m_BallRb.velocity.x);
        sensor.AddObservation(m_BallRb.velocity.y);
        // NOTE(review): this is the quaternion's z component, not a Euler angle. It is
        // kept unchanged because trained models presumably depend on this observation.
        sensor.AddObservation(m_InvertMult * gameObject.transform.rotation.z);
    }

    /// <summary>
    /// Applies a 3-element continuous action: [0] horizontal movement, [1] jump
    /// (only positive values act), [2] racket rotation. Each component is clamped to [-1, 1].
    /// Also accumulates the per-step time penalty and refreshes the score text.
    /// </summary>
    public override void OnActionReceived(float[] vectorAction)
    {
        var moveX = Mathf.Clamp(vectorAction[0], -1f, 1f) * m_InvertMult;
        var moveY = Mathf.Clamp(vectorAction[1], -1f, 1f);
        var rotate = Mathf.Clamp(vectorAction[2], -1f, 1f) * m_InvertMult;

        // Only a positive vertical action sets the upward velocity; otherwise the
        // current vertical velocity is left to the physics engine.
        if (moveY > 0f)
        {
            m_AgentRb.velocity = new Vector3(m_AgentRb.velocity.x, moveY * k_JumpSpeed, 0f);
        }

        m_AgentRb.velocity = new Vector3(moveX * k_HorizontalSpeed, m_AgentRb.velocity.y, 0f);
        m_AgentRb.transform.rotation =
            Quaternion.Euler(0f, -180f, k_RacketTiltRange * rotate + m_InvertMult * 90f);

        // Keep the agent on its own side: its x offset from the parent may not cross
        // -m_InvertMult (presumably the net side); if it does, snap it back to that boundary.
        var offsetX = transform.position.x - transform.parent.transform.position.x;
        if ((invertX && offsetX < -m_InvertMult) || (!invertX && offsetX > -m_InvertMult))
        {
            transform.position = new Vector3(-m_InvertMult + transform.parent.transform.position.x,
                transform.position.y,
                transform.position.z);
        }

        timePenalty -= k_TimePenaltyPerStep;
        m_TextComponent.text = score.ToString();
    }

    /// <summary>
    /// Keyboard control: horizontal axis moves the racket, Space jumps, and the
    /// vertical axis rotates the racket.
    /// </summary>
    public override float[] Heuristic()
    {
        var action = new float[3];
        action[0] = Input.GetAxis("Horizontal"); // Racket Movement
        action[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
        action[2] = Input.GetAxis("Vertical"); // Racket Rotation
        return action;
    }

    /// <summary>
    /// Rewards touching the ball, scaled by the "ball_touch" reset parameter
    /// (defaults to 0, which gives no reward).
    /// </summary>
    void OnCollisionEnter(Collision c)
    {
        if (c.gameObject.CompareTag("ball"))
        {
            AddReward(k_BallTouchReward * m_BallTouch);
        }
    }

    /// <summary>
    /// Resets per-episode state, re-reads reset parameters, and moves the agent back
    /// to its spawn position with zero velocity.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        timePenalty = 0;
        // Use the channel cached in Initialize(), as SetRacket()/SetBall() do.
        m_BallTouch = m_ResetParams.GetPropertyWithDefault("ball_touch", 0f);
        m_InvertMult = GetInvertMult();

        // Only the non-mirrored agent resets the area, presumably so the match
        // reset happens once rather than once per agent.
        if (!invertX)
        {
            m_Area.MatchReset();
        }

        transform.position = new Vector3(-m_InvertMult * 14f, 5f, -1.8f) + transform.parent.transform.position;
        m_AgentRb.velocity = Vector3.zero;
        SetResetParameters();
    }

    /// <summary>
    /// Reads the "angle" reset parameter (default 55) into <see cref="angle"/> and tilts
    /// the racket around Z by that amount, mirrored for the inverted agent.
    /// </summary>
    public void SetRacket()
    {
        angle = m_ResetParams.GetPropertyWithDefault("angle", 55);
        gameObject.transform.eulerAngles = new Vector3(
            gameObject.transform.eulerAngles.x,
            gameObject.transform.eulerAngles.y,
            m_InvertMult * angle
        );
    }

    /// <summary>
    /// Reads the "scale" reset parameter (default 0.5) into <see cref="scale"/> and
    /// applies it uniformly to the ball's local scale.
    /// </summary>
    public void SetBall()
    {
        scale = m_ResetParams.GetPropertyWithDefault("scale", .5f);
        ball.transform.localScale = new Vector3(scale, scale, scale);
    }

    /// <summary>Applies the racket and ball reset parameters.</summary>
    public void SetResetParameters()
    {
        SetRacket();
        SetBall();
    }

    /// <summary>Returns -1 for the mirrored (invertX) agent and +1 otherwise.</summary>
    float GetInvertMult()
    {
        return invertX ? -1f : 1f;
    }
}