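// The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents.
// (Web page residue) You may select at most 25 topics. A topic must begin with a Chinese character, a letter, or a digit, may contain hyphens (-), and must not exceed 35 characters.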
 
 
 
 
 

// 117 lines
// 3.9 KiB

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;
using MLAgents;
/// <summary>
/// Agent that controls one racket in the Tennis environment.
/// Horizontal quantities (observations, actions, spawn position, racket angle)
/// are multiplied by <c>invertMult</c>, which is -1 when <see cref="invertX"/>
/// is set and +1 otherwise, so both sides share one mirrored frame of reference.
/// </summary>
public class TennisAgent : Agent
{
    /// <summary>The ball object observed by this agent and rescaled by <see cref="SetBall"/>.</summary>
    [Header("Specific to Tennis")]
    public GameObject ball;

    /// <summary>When true, horizontal values are mirrored and the "ScoreB" board is used.</summary>
    public bool invertX;

    /// <summary>Current score; written to the scoreboard text on every action step.</summary>
    public int score;

    /// <summary>Area object whose position is the origin for all position observations.</summary>
    public GameObject myArea;

    /// <summary>Last value read from the "angle" reset parameter.</summary>
    public float angle;

    /// <summary>Last value read from the "scale" reset parameter.</summary>
    public float scale;

    private Text textComponent;
    private Rigidbody agentRb;
    private Rigidbody ballRb;
    private float invertMult;
    private ResetParameters resetParams;

    // Looks for the scoreboard based on the name of the gameObjects.
    // Do not modify the names of the Score GameObjects
    private const string CanvasName = "Canvas";
    private const string ScoreBoardAName = "ScoreA";
    private const string ScoreBoardBName = "ScoreB";

    /// <summary>
    /// Caches the rigidbodies, the academy's reset parameters and the scoreboard
    /// text component, then applies the reset parameters once.
    /// </summary>
    public override void InitializeAgent()
    {
        agentRb = GetComponent<Rigidbody>();
        ballRb = ball.GetComponent<Rigidbody>();

        // Must be assigned before SetResetParameters(): SetRacket() multiplies the
        // racket angle by invertMult, which would otherwise still hold its default
        // of 0 here and silently zero the angle until the first AgentReset().
        invertMult = invertX ? -1f : 1f;

        var academy = FindObjectOfType<Academy>();
        resetParams = academy.resetParameters;

        // The inverted agent reports to "ScoreB", the other one to "ScoreA".
        var canvas = GameObject.Find(CanvasName);
        var scoreBoardName = invertX ? ScoreBoardBName : ScoreBoardAName;
        var scoreBoard = canvas.transform.Find(scoreBoardName).gameObject;
        textComponent = scoreBoard.GetComponent<Text>();

        SetResetParameters();
    }

    /// <summary>
    /// Adds eight observations: position (x, y) and velocity (x, y) of the agent,
    /// followed by the same four values for the ball. Positions are relative to
    /// <see cref="myArea"/>; x components are multiplied by <c>invertMult</c>.
    /// </summary>
    public override void CollectObservations()
    {
        Vector3 areaPosition = myArea.transform.position;

        Vector3 agentPosition = transform.position;
        Vector3 agentVelocity = agentRb.velocity;
        AddVectorObs(invertMult * (agentPosition.x - areaPosition.x));
        AddVectorObs(agentPosition.y - areaPosition.y);
        AddVectorObs(invertMult * agentVelocity.x);
        AddVectorObs(agentVelocity.y);

        Vector3 ballPosition = ball.transform.position;
        Vector3 ballVelocity = ballRb.velocity;
        AddVectorObs(invertMult * (ballPosition.x - areaPosition.x));
        AddVectorObs(ballPosition.y - areaPosition.y);
        AddVectorObs(invertMult * ballVelocity.x);
        AddVectorObs(ballVelocity.y);
    }

    /// <summary>
    /// Applies one action step.
    /// </summary>
    /// <param name="vectorAction">
    /// Element 0 is the horizontal move and element 1 the jump signal; both are
    /// clamped to [-1, 1] before use.
    /// </param>
    /// <param name="textAction">Unused by this agent.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        var moveX = Mathf.Clamp(vectorAction[0], -1f, 1f) * invertMult;
        var moveY = Mathf.Clamp(vectorAction[1], -1f, 1f);

        // The parent transform is not modified in this method, so read it once.
        Vector3 parentPosition = transform.parent.position;

        // Jump only when requested and while the agent is below -1.5 relative to
        // its parent (presumably "standing on the ground" - confirm against scene).
        if (moveY > 0.5f && transform.position.y - parentPosition.y < -1.5f)
        {
            agentRb.velocity = new Vector3(agentRb.velocity.x, 7f, 0f);
        }

        agentRb.velocity = new Vector3(moveX * 30f, agentRb.velocity.y, 0f);

        // Clamp the agent at x = -invertMult relative to the parent when it moves
        // past that line (appears to keep each agent on its own side of the court).
        float offsetX = transform.position.x - parentPosition.x;
        bool crossedLimit = invertX
            ? offsetX < -invertMult
            : offsetX > -invertMult;
        if (crossedLimit)
        {
            transform.position = new Vector3(
                -invertMult + parentPosition.x,
                transform.position.y,
                transform.position.z);
        }

        textComponent.text = score.ToString();
    }

    /// <summary>
    /// Recomputes <c>invertMult</c>, places the agent at a random x offset in
    /// [6, 8] on its side of the parent at height -1.5, zeroes its velocity and
    /// re-applies the reset parameters.
    /// </summary>
    public override void AgentReset()
    {
        invertMult = invertX ? -1f : 1f;

        var spawnOffset = new Vector3(-invertMult * Random.Range(6f, 8f), -1.5f, 0f);
        transform.position = spawnOffset + transform.parent.position;
        agentRb.velocity = Vector3.zero;

        SetResetParameters();
    }

    /// <summary>
    /// Reads the "angle" reset parameter and sets the racket's Z euler angle to
    /// <c>invertMult * angle</c>, leaving the X and Y euler angles unchanged.
    /// </summary>
    public void SetRacket()
    {
        angle = resetParams["angle"];

        Vector3 currentEuler = gameObject.transform.eulerAngles;
        gameObject.transform.eulerAngles = new Vector3(
            currentEuler.x,
            currentEuler.y,
            invertMult * angle);
    }

    /// <summary>
    /// Reads the "scale" reset parameter and applies it uniformly to the ball's
    /// local scale.
    /// </summary>
    public void SetBall()
    {
        scale = resetParams["scale"];
        ball.transform.localScale = new Vector3(scale, scale, scale);
    }

    /// <summary>
    /// Applies all reset parameters: racket angle first, then ball scale.
    /// </summary>
    public void SetResetParameters()
    {
        SetRacket();
        SetBall();
    }
}