using System;
using UnityEngine;
using MLAgents;
using MLAgents.Policies;
using MLAgents.SideChannels;

/// <summary>
/// Soccer-playing agent. Translates discrete actions into movement and rotation,
/// kicks the ball on collision, and applies role-dependent per-step rewards.
/// </summary>
public class AgentSoccer : Agent
{
    // Note that the detectable tags are different for the blue and purple teams. The order is
    // * ball
    // * own goal
    // * opposing goal
    // * wall
    // * own teammate
    // * opposing player

    public enum Team
    {
        Blue = 0,
        Purple = 1
    }

    public enum Position
    {
        Striker,
        Goalie,
        Generic
    }

    [HideInInspector]
    public Team team;

    // Kick strength multiplier for the current step: reset to 0 on every action
    // and set to 1 only while the agent is moving forward (see MoveAgent).
    float m_KickPower;
    int m_PlayerIndex;
    public SoccerFieldArea area;

    // The coefficient for the reward for colliding with a ball. Set using curriculum.
    float m_BallTouch;
    public Position position;

    const float k_Power = 2000f;

    // Magnitude of the per-step existential reward/penalty (1 / maxStep).
    float m_Existential;
    float m_LateralSpeed;
    float m_ForwardSpeed;

    // Accumulated (negative) time penalty for Generic agents; reset each episode.
    [HideInInspector]
    public float timePenalty = 0;

    [HideInInspector]
    public Rigidbody agentRb;
    SoccerSettings m_SoccerSettings;
    BehaviorParameters m_BehaviorParameters;

    // Position the agent is moved to at the start of every episode.
    Vector3 m_Transform;

    /// <summary>
    /// Caches components, derives the team from the behavior parameters, picks
    /// role-dependent speeds, and registers this agent with its field area.
    /// </summary>
    public override void Initialize()
    {
        // Guard against maxStep == 0: dividing by it would yield an infinite
        // per-step reward.
        m_Existential = maxStep > 0 ? 1f / maxStep : 0f;

        m_BehaviorParameters = gameObject.GetComponent<BehaviorParameters>();
        if (m_BehaviorParameters.TeamId == (int)Team.Blue)
        {
            team = Team.Blue;
            m_Transform = new Vector3(transform.position.x - 4f, .5f, transform.position.z);
        }
        else
        {
            team = Team.Purple;
            m_Transform = new Vector3(transform.position.x + 4f, .5f, transform.position.z);
        }

        if (position == Position.Goalie)
        {
            m_LateralSpeed = 1.0f;
            m_ForwardSpeed = 1.0f;
        }
        else if (position == Position.Striker)
        {
            m_LateralSpeed = 0.3f;
            m_ForwardSpeed = 1.3f;
        }
        else
        {
            m_LateralSpeed = 0.3f;
            m_ForwardSpeed = 1.0f;
        }

        m_SoccerSettings = FindObjectOfType<SoccerSettings>();
        agentRb = GetComponent<Rigidbody>();
        agentRb.maxAngularVelocity = 500;

        var playerState = new PlayerState
        {
            agentRb = agentRb,
            startingPos = transform.position,
            agentScript = this,
        };
        area.playerStates.Add(playerState);
        m_PlayerIndex = area.playerStates.IndexOf(playerState);
        playerState.playerIndex = m_PlayerIndex;
    }

    /// <summary>
    /// Applies one step of movement from a discrete action vector.
    /// </summary>
    /// <param name="act">
    /// Three branches: [0] forward axis (1 = forward, 2 = backward),
    /// [1] right axis (1 = right, 2 = left), [2] rotate axis
    /// (1 = rotate about -up, 2 = rotate about +up). Any other value is a no-op.
    /// </param>
    public void MoveAgent(float[] act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        m_KickPower = 0f;

        var forwardAxis = (int)act[0];
        var rightAxis = (int)act[1];
        var rotateAxis = (int)act[2];

        switch (forwardAxis)
        {
            case 1:
                dirToGo = transform.forward * m_ForwardSpeed;
                // Only a forward-moving agent kicks with power.
                m_KickPower = 1f;
                break;
            case 2:
                dirToGo = transform.forward * -m_ForwardSpeed;
                break;
        }

        // NOTE: a lateral action replaces (does not add to) the forward/backward
        // direction chosen above. Kept as-is to preserve existing behavior.
        switch (rightAxis)
        {
            case 1:
                dirToGo = transform.right * m_LateralSpeed;
                break;
            case 2:
                dirToGo = transform.right * -m_LateralSpeed;
                break;
        }

        switch (rotateAxis)
        {
            case 1:
                rotateDir = transform.up * -1f;
                break;
            case 2:
                rotateDir = transform.up * 1f;
                break;
        }

        transform.Rotate(rotateDir, Time.deltaTime * 100f);
        agentRb.AddForce(dirToGo * m_SoccerSettings.agentRunSpeed,
            ForceMode.VelocityChange);
    }

    /// <summary>
    /// Applies the role-dependent existential reward and then moves the agent.
    /// </summary>
    /// <param name="vectorAction">Discrete action vector forwarded to <see cref="MoveAgent"/>.</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        if (position == Position.Goalie)
        {
            // Existential bonus for Goalies.
            AddReward(m_Existential);
        }
        else if (position == Position.Striker)
        {
            // Existential penalty for Strikers
            AddReward(-m_Existential);
        }
        else
        {
            // Existential penalty cumulant for Generic
            timePenalty -= m_Existential;
        }

        MoveAgent(vectorAction);
    }

    /// <summary>
    /// Keyboard control: W/S forward/backward, A/D rotate, E/Q strafe right/left.
    /// When both keys of a pair are held, the later check (S, D, Q) wins.
    /// </summary>
    /// <param name="actionsOut">Action buffer to fill; it is zeroed first.</param>
    public override void Heuristic(float[] actionsOut)
    {
        // Zero the buffer so that a released key does not leave its previous
        // action in place if the caller reuses the array between calls.
        Array.Clear(actionsOut, 0, actionsOut.Length);

        //forward
        if (Input.GetKey(KeyCode.W))
        {
            actionsOut[0] = 1f;
        }
        if (Input.GetKey(KeyCode.S))
        {
            actionsOut[0] = 2f;
        }
        //rotate
        if (Input.GetKey(KeyCode.A))
        {
            actionsOut[2] = 1f;
        }
        if (Input.GetKey(KeyCode.D))
        {
            actionsOut[2] = 2f;
        }
        //right
        if (Input.GetKey(KeyCode.E))
        {
            actionsOut[1] = 1f;
        }
        if (Input.GetKey(KeyCode.Q))
        {
            actionsOut[1] = 2f;
        }
    }

    /// <summary>
    /// Used to provide a "kick" to the ball.
    /// </summary>
    /// <param name="c">Collision data; only objects tagged "ball" are handled.</param>
    void OnCollisionEnter(Collision c)
    {
        var force = k_Power * m_KickPower;
        if (position == Position.Goalie)
        {
            // Goalies always kick at full power, regardless of movement.
            force = k_Power;
        }
        if (c.gameObject.CompareTag("ball"))
        {
            AddReward(.2f * m_BallTouch);
            // Push the ball away along the agent-to-contact direction.
            var dir = c.contacts[0].point - transform.position;
            dir = dir.normalized;
            c.gameObject.GetComponent<Rigidbody>().AddForce(dir * force);
        }
    }

    /// <summary>
    /// Resets per-episode state: time penalty, ball-touch coefficient (read from
    /// the "ball_touch" float property, default 0), orientation, position,
    /// velocities, and the ball.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        timePenalty = 0;
        m_BallTouch = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>()
            .GetPropertyWithDefault("ball_touch", 0);

        if (team == Team.Purple)
        {
            transform.rotation = Quaternion.Euler(0f, -90f, 0f);
        }
        else
        {
            transform.rotation = Quaternion.Euler(0f, 90f, 0f);
        }

        transform.position = m_Transform;
        agentRb.velocity = Vector3.zero;
        agentRb.angularVelocity = Vector3.zero;
        SetResetParameters();
    }

    /// <summary>
    /// Asks the field area to reset the ball.
    /// </summary>
    public void SetResetParameters()
    {
        area.ResetBall();
    }
}