Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

242 行
6.2 KiB

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
/// <summary>
/// Agent that plays on one of two teams inside a <see cref="SoccerFieldArea"/>.
/// It moves and rotates from three discrete action branches, kicks the ball on
/// collision, and receives a per-step existential reward or penalty that depends
/// on its <see cref="Position"/>.
/// </summary>
public class AgentSoccer : Agent
{
    // Note that the detectable tags are different for the blue and purple teams. The order is
    // * ball
    // * own goal
    // * opposing goal
    // * wall
    // * own teammate
    // * opposing player

    /// <summary>Team identifiers; the numeric values are compared against BehaviorParameters.TeamId.</summary>
    public enum Team
    {
        Blue = 0,
        Purple = 1
    }

    /// <summary>Role of the agent; selects movement speeds and the per-step reward scheme.</summary>
    public enum Position
    {
        Striker,
        Goalie,
        Generic
    }

    // Team of this agent, derived from BehaviorParameters.TeamId in Initialize().
    [HideInInspector]
    public Team team;

    // Kick strength multiplier for the current step: 1 when moving forward, otherwise 0.
    float m_KickPower;

    // Index of this agent's PlayerState inside area.playerStates.
    int m_PlayerIndex;

    // Field this agent belongs to; must be assigned before Initialize() runs.
    public SoccerFieldArea area;

    // The coefficient for the reward for colliding with a ball. Set using curriculum.
    float m_BallTouch;

    public Position position;

    // Base force applied to the ball on a kick.
    const float k_Power = 2000f;

    // Magnitude of the per-step existential reward/penalty (1 / MaxStep, or 0 without a step limit).
    float m_Existential;

    float m_LateralSpeed;
    float m_ForwardSpeed;

    // Accumulated existential penalty for Generic agents; reset at the start of each episode.
    [HideInInspector]
    public float timePenalty;

    [HideInInspector]
    public Rigidbody agentRb;

    SoccerSettings m_SoccerSettings;
    BehaviorParameters m_BehaviorParameters;

    // World position the agent is moved to at the start of every episode.
    Vector3 m_Transform;

    EnvironmentParameters m_ResetParams;

    /// <summary>
    /// Resolves the team from the behavior parameters, computes the episode start
    /// position and movement speeds, caches component references, and registers
    /// this agent with its <see cref="area"/>.
    /// </summary>
    public override void Initialize()
    {
        // A MaxStep of 0 means "no step limit" in ML-Agents; dividing by it would
        // produce an infinite per-step reward, so fall back to no existential term.
        m_Existential = MaxStep > 0 ? 1f / MaxStep : 0f;

        m_BehaviorParameters = gameObject.GetComponent<BehaviorParameters>();
        if (m_BehaviorParameters.TeamId == (int)Team.Blue)
        {
            team = Team.Blue;
            m_Transform = new Vector3(transform.position.x - 4f, .5f, transform.position.z);
        }
        else
        {
            team = Team.Purple;
            m_Transform = new Vector3(transform.position.x + 4f, .5f, transform.position.z);
        }

        if (position == Position.Goalie)
        {
            m_LateralSpeed = 1.0f;
            m_ForwardSpeed = 1.0f;
        }
        else if (position == Position.Striker)
        {
            m_LateralSpeed = 0.3f;
            m_ForwardSpeed = 1.3f;
        }
        else
        {
            m_LateralSpeed = 0.3f;
            m_ForwardSpeed = 1.0f;
        }

        m_SoccerSettings = FindObjectOfType<SoccerSettings>();
        agentRb = GetComponent<Rigidbody>();
        agentRb.maxAngularVelocity = 500;

        var playerState = new PlayerState
        {
            agentRb = agentRb,
            startingPos = transform.position,
            agentScript = this,
        };
        area.playerStates.Add(playerState);
        m_PlayerIndex = area.playerStates.IndexOf(playerState);
        playerState.playerIndex = m_PlayerIndex;

        m_ResetParams = Academy.Instance.EnvironmentParameters;
    }

    /// <summary>
    /// Applies one step of movement from the discrete action branches.
    /// </summary>
    /// <param name="act">
    /// Branch 0: forward axis (1 = forward, 2 = backward).
    /// Branch 1: lateral axis (1 = right, 2 = left).
    /// Branch 2: rotation (1 = negative rotation about up, 2 = positive rotation about up).
    /// Any other value on a branch leaves that axis idle.
    /// </param>
    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        // Kick power is only granted on steps where the agent moves forward.
        m_KickPower = 0f;

        var forwardAxis = act[0];
        var rightAxis = act[1];
        var rotateAxis = act[2];

        switch (forwardAxis)
        {
            case 1:
                dirToGo = transform.forward * m_ForwardSpeed;
                m_KickPower = 1f;
                break;
            case 2:
                dirToGo = transform.forward * -m_ForwardSpeed;
                break;
        }

        // A lateral action replaces (does not add to) the forward/backward
        // direction chosen above; m_KickPower is left as set by the forward axis.
        switch (rightAxis)
        {
            case 1:
                dirToGo = transform.right * m_LateralSpeed;
                break;
            case 2:
                dirToGo = transform.right * -m_LateralSpeed;
                break;
        }

        switch (rotateAxis)
        {
            case 1:
                rotateDir = transform.up * -1f;
                break;
            case 2:
                rotateDir = transform.up * 1f;
                break;
        }

        transform.Rotate(rotateDir, Time.deltaTime * 100f);
        agentRb.AddForce(dirToGo * m_SoccerSettings.agentRunSpeed,
            ForceMode.VelocityChange);
    }

    /// <summary>
    /// Applies the position-dependent existential reward and then moves the agent
    /// using the discrete actions in <paramref name="actionBuffers"/>.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        if (position == Position.Goalie)
        {
            // Existential bonus for Goalies.
            AddReward(m_Existential);
        }
        else if (position == Position.Striker)
        {
            // Existential penalty for Strikers
            AddReward(-m_Existential);
        }
        else
        {
            // Existential penalty cumulant for Generic
            timePenalty -= m_Existential;
        }

        MoveAgent(actionBuffers.DiscreteActions);
    }

    /// <summary>
    /// Keyboard control: W/S drive the forward axis, A/D rotate, E/Q strafe.
    /// When both keys of a pair are held, the second one checked (S, D, Q) wins.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut.Clear();

        //forward
        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 2;
        }

        //rotate
        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[2] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[2] = 2;
        }

        //right
        if (Input.GetKey(KeyCode.E))
        {
            discreteActionsOut[1] = 1;
        }
        if (Input.GetKey(KeyCode.Q))
        {
            discreteActionsOut[1] = 2;
        }
    }

    /// <summary>
    /// Used to provide a "kick" to the ball.
    /// Collisions with objects not tagged "ball" are ignored.
    /// </summary>
    void OnCollisionEnter(Collision c)
    {
        if (!c.gameObject.CompareTag("ball"))
        {
            return;
        }

        // Goalies always kick at full power; other positions only while moving forward.
        var force = position == Position.Goalie ? k_Power : k_Power * m_KickPower;

        AddReward(.2f * m_BallTouch);

        // GetContact(0) reads the first contact without allocating the array
        // that the Collision.contacts property creates on every access.
        var dir = (c.GetContact(0).point - transform.position).normalized;

        var ballRb = c.gameObject.GetComponent<Rigidbody>();
        if (ballRb != null)
        {
            ballRb.AddForce(dir * force);
        }
    }

    /// <summary>
    /// Resets the per-episode state: clears the time penalty, refreshes the
    /// "ball_touch" coefficient from the environment parameters, faces the agent
    /// according to its team, moves it to its start position, zeroes its
    /// velocities, and resets the ball.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        timePenalty = 0;
        m_BallTouch = m_ResetParams.GetWithDefault("ball_touch", 0);

        if (team == Team.Purple)
        {
            transform.rotation = Quaternion.Euler(0f, -90f, 0f);
        }
        else
        {
            transform.rotation = Quaternion.Euler(0f, 90f, 0f);
        }

        transform.position = m_Transform;
        agentRb.velocity = Vector3.zero;
        agentRb.angularVelocity = Vector3.zero;
        SetResetParameters();
    }

    /// <summary>Asks the owning field area to reset the ball.</summary>
    public void SetResetParameters()
    {
        area.ResetBall();
    }
}