Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。
 
 
 
 
 

215 行
5.9 KiB

using UnityEngine;
using MLAgents;
/// <summary>
/// Soccer agent that plays either as a striker or as a goalie for the purple
/// or blue team. Observations come from ray casts, the action is a single
/// discrete index, and the ball is kicked on contact while moving forward.
/// </summary>
public class AgentSoccer : Agent
{
    /// <summary>
    /// Team the agent plays for. Member order matters: ChooseRandomTeam casts
    /// the integers 0/1 to this enum.
    /// </summary>
    public enum Team
    {
        Purple,
        Blue
    }

    /// <summary>
    /// Role of the agent. Strikers and goalies use different action mappings
    /// and receive opposite per-step existential rewards.
    /// </summary>
    public enum AgentRole
    {
        Striker,
        Goalie
    }

    public Team team;
    public AgentRole agentRole;

    // Multiplier for the kick impulse in OnCollisionEnter; MoveAgent sets it
    // to 1 only for the "move forward" action and to 0 otherwise.
    float m_KickPower;

    // Position of this agent's PlayerState inside area.playerStates.
    int m_PlayerIndex;

    public SoccerFieldArea area;

    [HideInInspector]
    public Rigidbody agentRb;

    SoccerAcademy m_Academy;
    Renderer m_AgentRenderer;
    RayPerception m_RayPer;

    // Ray directions handed to RayPerception.Perceive
    // (presumably degrees -- confirm against RayPerception).
    float[] m_RayAngles = { 0f, 45f, 90f, 135f, 180f, 110f, 70f };

    // Tags are listed relative to the agent's own team (ball, own goal,
    // opposing goal, wall, teammates, opponents), so each slot has the same
    // meaning regardless of which team the agent is on.
    string[] m_DetectableObjectsPurple = { "ball", "purpleGoal", "blueGoal",
        "wall", "purpleAgent", "blueAgent" };
    string[] m_DetectableObjectsBlue = { "ball", "blueGoal", "purpleGoal",
        "wall", "blueAgent", "purpleAgent" };

    /// <summary>
    /// Randomly assigns the agent to Purple or Blue (via Random.Range(0, 2))
    /// and joins that team while keeping the current role.
    /// </summary>
    public void ChooseRandomTeam()
    {
        team = (Team)Random.Range(0, 2);
        JoinAssignedTeam();
    }

    /// <summary>
    /// Joins the purple team: stores the role and team, applies the academy's
    /// purple material and tags the GameObject "purpleAgent".
    /// </summary>
    /// <param name="role">Role to play on the purple team.</param>
    public void JoinPurpleTeam(AgentRole role)
    {
        agentRole = role;
        team = Team.Purple;
        m_AgentRenderer.material = m_Academy.purpleMaterial;
        tag = "purpleAgent";
    }

    /// <summary>
    /// Joins the blue team: stores the role and team, applies the academy's
    /// blue material and tags the GameObject "blueAgent".
    /// </summary>
    /// <param name="role">Role to play on the blue team.</param>
    public void JoinBlueTeam(AgentRole role)
    {
        agentRole = role;
        team = Team.Blue;
        m_AgentRenderer.material = m_Academy.blueMaterial;
        tag = "blueAgent";
    }

    // Re-applies whatever is stored in `team`: Purple joins purple, any other
    // value joins blue. The current agentRole is kept.
    void JoinAssignedTeam()
    {
        switch (team)
        {
            case Team.Purple:
                JoinPurpleTeam(agentRole);
                break;
            default:
                JoinBlueTeam(agentRole);
                break;
        }
    }

    /// <summary>
    /// Caches component and academy references, caps the rigidbody's angular
    /// velocity, and registers a PlayerState for this agent with the field area.
    /// </summary>
    public override void InitializeAgent()
    {
        base.InitializeAgent();

        m_AgentRenderer = GetComponentInChildren<Renderer>();
        m_RayPer = GetComponent<RayPerception>();
        m_Academy = FindObjectOfType<SoccerAcademy>();

        agentRb = GetComponent<Rigidbody>();
        agentRb.maxAngularVelocity = 500;

        var state = new PlayerState
        {
            agentRb = agentRb,
            startingPos = transform.position,
            agentScript = this,
        };
        area.playerStates.Add(state);

        // Record where the state landed, both on the agent and on the state itself.
        state.playerIndex = m_PlayerIndex = area.playerStates.IndexOf(state);
    }

    /// <summary>
    /// Adds two ray-perception observation sets over the team-relative tag
    /// list: one with Perceive's default offsets and one with start/end offsets of 1.
    /// </summary>
    public override void CollectObservations()
    {
        const float rayDistance = 20f;

        var detectableObjects = team == Team.Purple
            ? m_DetectableObjectsPurple
            : m_DetectableObjectsBlue;

        AddVectorObs(m_RayPer.Perceive(rayDistance, m_RayAngles, detectableObjects));
        AddVectorObs(m_RayPer.Perceive(rayDistance, m_RayAngles, detectableObjects, 1f, 1f));
    }

    /// <summary>
    /// Converts the discrete action in act[0] into movement and rotation.
    /// Both roles: 1 = forward (and arms the kick), 2 = backward.
    /// Goalie: 3 = strafe right, 4 = strafe left.
    /// Striker: 3/4 = turn about the up axis, 5/6 = strafe left/right at 0.75 speed.
    /// Any other index leaves the agent still and the kick disarmed.
    /// </summary>
    /// <param name="act">Action vector; only the first element is read (floored to an int).</param>
    public void MoveAgent(float[] act)
    {
        var moveDir = Vector3.zero;
        var turnAxis = Vector3.zero;
        var action = Mathf.FloorToInt(act[0]);
        var isGoalie = agentRole == AgentRole.Goalie;

        // The kick is armed only by the forward action; every step starts disarmed.
        m_KickPower = 0f;

        switch (action)
        {
            case 1:
                moveDir = transform.forward;
                m_KickPower = 1f;
                break;
            case 2:
                moveDir = -transform.forward;
                break;
            case 3:
                if (isGoalie)
                {
                    moveDir = transform.right;
                }
                else
                {
                    turnAxis = transform.up;
                }
                break;
            case 4:
                if (isGoalie)
                {
                    moveDir = -transform.right;
                }
                else
                {
                    turnAxis = -transform.up;
                }
                break;
            case 5:
                if (!isGoalie)
                {
                    moveDir = transform.right * -0.75f;
                }
                break;
            case 6:
                if (!isGoalie)
                {
                    moveDir = transform.right * 0.75f;
                }
                break;
        }

        transform.Rotate(turnAxis, Time.deltaTime * 100f);
        agentRb.AddForce(moveDir * m_Academy.agentRunSpeed,
            ForceMode.VelocityChange);
    }

    /// <summary>
    /// Applies the per-step existential reward (strikers are penalized,
    /// goalies rewarded, both by 1/3000) and then executes the action.
    /// </summary>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        const float existentialReward = 1f / 3000f;

        switch (agentRole)
        {
            case AgentRole.Striker:
                AddReward(-existentialReward);
                break;
            case AgentRole.Goalie:
                AddReward(existentialReward);
                break;
        }

        MoveAgent(vectorAction);
    }

    /// <summary>
    /// Used to provide a "kick" to the ball: on touching a "ball"-tagged object,
    /// pushes it away from the agent along the first contact point with a force
    /// of 2000 scaled by the current kick power.
    /// </summary>
    void OnCollisionEnter(Collision c)
    {
        if (!c.gameObject.CompareTag("ball"))
        {
            return;
        }

        var force = 2000f * m_KickPower;
        var kickDir = (c.contacts[0].point - transform.position).normalized;
        c.gameObject.GetComponent<Rigidbody>().AddForce(kickDir * force);
    }

    /// <summary>
    /// Resets the agent for a new episode. It optionally re-rolls the team when
    /// the academy requests it, re-applies team visuals, faces the team's yaw
    /// (-90 for purple, +90 for blue), teleports to a random spawn point,
    /// clears velocities, and resets the ball.
    /// </summary>
    public override void AgentReset()
    {
        if (m_Academy.randomizePlayersTeamForTraining)
        {
            ChooseRandomTeam();
        }

        var onPurple = team == Team.Purple;
        JoinAssignedTeam();
        transform.rotation = Quaternion.Euler(0f, onPurple ? -90f : 90f, 0f);

        transform.position = area.GetRandomSpawnPos(agentRole, team);
        agentRb.velocity = Vector3.zero;
        agentRb.angularVelocity = Vector3.zero;

        SetResetParameters();
    }

    /// <summary>
    /// Applies reset-time parameters; currently this only asks the field area
    /// to reset the ball.
    /// </summary>
    public void SetResetParameters()
    {
        area.ResetBall();
    }
}