using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

/// <summary>
/// Agent for the FoodCollector environment. It moves, strafes, rotates and can fire a laser
/// that freezes other agents. Touching an object tagged "food" gives +1 reward; touching one
/// tagged "badFood" gives -1.
/// </summary>
public class FoodCollectorAgent : Agent
{
    // Seconds an agent stays frozen after being hit by another agent's laser.
    const float k_FreezeDuration = 4f;
    // Seconds the poisoned/satiated material stays on after eating (shared timer).
    const float k_EffectDuration = 0.5f;
    // Squared speed above which velocity is damped each step (i.e. speed > 5).
    const float k_MaxSqrSpeed = 25f;
    const float k_OverspeedDamping = 0.95f;
    // While shooting, the movement impulse is halved and existing velocity is scaled down.
    const float k_ShootMoveScale = 0.5f;
    const float k_ShootVelocityScale = 0.75f;
    // Laser hit test: a sphere of this radius swept this far along transform.forward.
    const float k_LaserRange = 25f;
    const float k_LaserRadius = 2f;

    FoodCollectorSettings m_FoodCollectorSettings;
    public GameObject area;
    FoodCollectorArea m_MyArea;
    bool m_Frozen;
    bool m_Poisoned;
    bool m_Satiated;
    bool m_Shoot;
    float m_FrozenTime;
    float m_EffectTime;
    Rigidbody m_AgentRb;
    float m_LaserLength;
    // Speed of agent rotation.
    public float turnSpeed = 300;

    // Speed of agent movement.
    public float moveSpeed = 2;
    public Material normalMaterial;
    public Material badMaterial;
    public Material goodMaterial;
    public Material frozenMaterial;
    public GameObject myLaser;
    public bool contribute;
    public bool useVectorObs;
    [Tooltip("Use only the frozen flag in vector observations. If \"Use Vector Obs\" " +
        "is checked, this option has no effect. This option is necessary for the " +
        "VisualFoodCollector scene.")]
    public bool useVectorFrozenFlag;

    EnvironmentParameters m_ResetParams;

    /// <summary>
    /// Caches the rigidbody, the owning area and the scene-wide settings object, then applies
    /// the environment reset parameters.
    /// </summary>
    public override void Initialize()
    {
        m_AgentRb = GetComponent<Rigidbody>();
        m_MyArea = area.GetComponent<FoodCollectorArea>();
        m_FoodCollectorSettings = FindObjectOfType<FoodCollectorSettings>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        SetResetParameters();
    }

    /// <summary>
    /// With <see cref="useVectorObs"/>: local x/z velocity, frozen flag and shoot flag.
    /// Otherwise, with <see cref="useVectorFrozenFlag"/>: only the frozen flag.
    /// Otherwise no vector observations are added.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        if (useVectorObs)
        {
            var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity);
            sensor.AddObservation(localVelocity.x);
            sensor.AddObservation(localVelocity.z);
            sensor.AddObservation(m_Frozen);
            sensor.AddObservation(m_Shoot);
        }
        else if (useVectorFrozenFlag)
        {
            sensor.AddObservation(m_Frozen);
        }
    }

    /// <summary>
    /// Converts a 0xRRGGBB integer into an opaque <see cref="Color32"/>.
    /// </summary>
    public Color32 ToColor(int hexVal)
    {
        var r = (byte)((hexVal >> 16) & 0xFF);
        var g = (byte)((hexVal >> 8) & 0xFF);
        var b = (byte)(hexVal & 0xFF);
        return new Color32(r, g, b, 255);
    }

    /// <summary>
    /// Applies one step of discrete actions.
    /// Branch 0: 1 = forward, 2 = backward. Branch 1: 1 = right, 2 = left.
    /// Branch 2: 1 / 2 = rotate in opposite directions about the up axis. Branch 3: 1 = shoot.
    /// Frozen agents ignore all actions. Forward and strafe impulses are summed, so diagonal
    /// movement receives both.
    /// </summary>
    public void MoveAgent(ActionSegment<int> act)
    {
        m_Shoot = false;
        ExpireStatusEffects();

        if (!m_Frozen)
        {
            ApplyActions(act);
        }

        if (m_AgentRb.velocity.sqrMagnitude > k_MaxSqrSpeed) // slow it down
        {
            m_AgentRb.velocity *= k_OverspeedDamping;
        }

        UpdateLaser();
    }

    // Ends the freeze and the poisoned/satiated effects whose durations have elapsed.
    void ExpireStatusEffects()
    {
        if (m_Frozen && Time.time > m_FrozenTime + k_FreezeDuration)
        {
            Unfreeze();
        }
        if (Time.time > m_EffectTime + k_EffectDuration)
        {
            if (m_Poisoned)
            {
                Unpoison();
            }
            if (m_Satiated)
            {
                Unsatiate();
            }
        }
    }

    // Translates the discrete action branches into movement, rotation and the shoot flag.
    void ApplyActions(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        switch (act[0])
        {
            case 1:
                dirToGo += transform.forward;
                break;
            case 2:
                dirToGo -= transform.forward;
                break;
        }
        // Accumulate rather than overwrite so strafing does not cancel forward movement.
        switch (act[1])
        {
            case 1:
                dirToGo += transform.right;
                break;
            case 2:
                dirToGo -= transform.right;
                break;
        }
        switch (act[2])
        {
            case 1:
                rotateDir = -transform.up;
                break;
            case 2:
                rotateDir = transform.up;
                break;
        }

        if (act[3] == 1)
        {
            m_Shoot = true;
            dirToGo *= k_ShootMoveScale;
            m_AgentRb.velocity *= k_ShootVelocityScale;
        }

        m_AgentRb.AddForce(dirToGo * moveSpeed, ForceMode.VelocityChange);
        transform.Rotate(rotateDir, Time.fixedDeltaTime * turnSpeed);
    }

    // Shows or hides the laser. When shooting, freezes the first object tagged "agent"
    // found by the sphere cast.
    // NOTE(review): the hit test always uses k_LaserRange; m_LaserLength ("laser_length")
    // only scales the visible laser. Confirm this is intended.
    void UpdateLaser()
    {
        if (!m_Shoot)
        {
            myLaser.transform.localScale = new Vector3(0f, 0f, 0f);
            return;
        }

        var myTransform = transform;
        myLaser.transform.localScale = new Vector3(1f, 1f, m_LaserLength);
        var rayDir = k_LaserRange * myTransform.forward;
        Debug.DrawRay(myTransform.position, rayDir, Color.red, 0f, true);
        RaycastHit hit;
        if (Physics.SphereCast(myTransform.position, k_LaserRadius, rayDir, out hit, k_LaserRange)
            && hit.collider.gameObject.CompareTag("agent"))
        {
            hit.collider.gameObject.GetComponent<FoodCollectorAgent>().Freeze();
        }
    }

    // Re-tags the agent as "frozenAgent" (so the laser, which only targets "agent", skips it)
    // and starts the freeze timer.
    void Freeze()
    {
        gameObject.tag = "frozenAgent";
        m_Frozen = true;
        m_FrozenTime = Time.time;
        SetBodyMaterial(frozenMaterial);
    }

    void Unfreeze()
    {
        m_Frozen = false;
        gameObject.tag = "agent";
        RestoreStatusMaterial();
    }

    // Poisoned and satiated share m_EffectTime, so eating either restarts the shared timer.
    void Poison()
    {
        m_Poisoned = true;
        m_EffectTime = Time.time;
        SetBodyMaterial(badMaterial);
    }

    void Unpoison()
    {
        m_Poisoned = false;
        RestoreStatusMaterial();
    }

    void Satiate()
    {
        m_Satiated = true;
        m_EffectTime = Time.time;
        SetBodyMaterial(goodMaterial);
    }

    void Unsatiate()
    {
        m_Satiated = false;
        RestoreStatusMaterial();
    }

    // After a status ends, show the material of any status still active instead of
    // unconditionally reverting to normal. For example, a still-frozen agent keeps its frozen look.
    void RestoreStatusMaterial()
    {
        if (m_Frozen)
        {
            SetBodyMaterial(frozenMaterial);
        }
        else if (m_Poisoned)
        {
            SetBodyMaterial(badMaterial);
        }
        else if (m_Satiated)
        {
            SetBodyMaterial(goodMaterial);
        }
        else
        {
            SetBodyMaterial(normalMaterial);
        }
    }

    // Assigns the material of the first Renderer found on this object or its children.
    void SetBodyMaterial(Material material)
    {
        gameObject.GetComponentInChildren<Renderer>().material = material;
    }

    /// <summary>
    /// Receives the discrete action branches and forwards them to <see cref="MoveAgent"/>.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        MoveAgent(actionBuffers.DiscreteActions);
    }

    /// <summary>
    /// Keyboard control: W/S move forward/back, A/D rotate, Space shoots. Strafing
    /// (branch 1) is not mapped.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = 0;
        discreteActionsOut[1] = 0;
        discreteActionsOut[2] = 0;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[2] = 2;
        }
        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[2] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 2;
        }
        discreteActionsOut[3] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }

    /// <summary>
    /// Clears all status effects, stops the agent and hides the laser. Places the agent at a
    /// random position within the area's range and a random yaw, then reapplies the
    /// reset parameters.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        Unfreeze();
        Unpoison();
        Unsatiate();
        m_Shoot = false;
        m_AgentRb.velocity = Vector3.zero;
        myLaser.transform.localScale = new Vector3(0f, 0f, 0f);
        transform.position = new Vector3(Random.Range(-m_MyArea.range, m_MyArea.range),
            2f, Random.Range(-m_MyArea.range, m_MyArea.range))
            + area.transform.position;
        transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));

        SetResetParameters();
    }

    // Eating "food" gives +1 reward and "badFood" gives -1. Each tells the food it was eaten
    // and, when contribute is set, updates the shared total score.
    void OnCollisionEnter(Collision collision)
    {
        if (collision.gameObject.CompareTag("food"))
        {
            Satiate();
            collision.gameObject.GetComponent<FoodLogic>().OnEaten();
            AddReward(1f);
            if (contribute)
            {
                m_FoodCollectorSettings.totalScore += 1;
            }
        }
        if (collision.gameObject.CompareTag("badFood"))
        {
            Poison();
            collision.gameObject.GetComponent<FoodLogic>().OnEaten();

            AddReward(-1f);
            if (contribute)
            {
                m_FoodCollectorSettings.totalScore -= 1;
            }
        }
    }

    /// <summary>Reads the visible laser length from the "laser_length" parameter (default 1).</summary>
    public void SetLaserLengths()
    {
        m_LaserLength = m_ResetParams.GetWithDefault("laser_length", 1.0f);
    }

    /// <summary>Applies a uniform scale from the "agent_scale" parameter (default 1).</summary>
    public void SetAgentScale()
    {
        float agentScale = m_ResetParams.GetWithDefault("agent_scale", 1.0f);
        gameObject.transform.localScale = new Vector3(agentScale, agentScale, agentScale);
    }

    /// <summary>Applies all environment reset parameters used by this agent.</summary>
    public void SetResetParameters()
    {
        SetLaserLengths();
        SetAgentScale();
    }
}