using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

/// <summary>
/// ML-Agents agent for the Hallway scene. Each episode one of two symbols (O or X)
/// is placed above the ground and the other is parked far below it; the agent is
/// rewarded for colliding with the goal tagged for the shown symbol and penalized
/// for colliding with the other goal. Either collision ends the episode.
/// </summary>
public class HallwayAgent : Agent
{
    // Values of discrete action branch 0 (used by MoveAgent and Heuristic).
    const int k_ActionNone = 0;
    const int k_ActionForward = 1;
    const int k_ActionBackward = 2;
    const int k_ActionTurnRight = 3;
    const int k_ActionTurnLeft = 4;

    // Turn rate passed to Transform.Rotate, in degrees per second.
    const float k_TurnSpeed = 150f;

    const float k_CorrectGoalReward = 1f;
    const float k_WrongGoalReward = -0.1f;
    // How long the ground shows the success/failure material after a goal is touched.
    const float k_GroundFlashSeconds = 0.5f;

    // Spawn layout, relative to ground.transform.position.
    const float k_AgentHeight = 1f;
    const float k_AgentZOffset = -15f;
    const float k_SymbolHeight = 2f;
    const float k_SymbolZOffset = 0f;
    const float k_HiddenSymbolHeight = -1000f;
    const float k_SpawnHalfWidthX = 3f;
    const float k_SpawnHalfDepthZ = 5f;

    // Goal layout, relative to area.transform.position.
    const float k_GoalSideX = 7f;
    const float k_GoalHeight = 0.5f;
    const float k_GoalZ = 22.29f;

    /// <summary>Ground object: anchors symbol and agent spawn positions; its renderer is flashed on goal contact.</summary>
    public GameObject ground;
    /// <summary>Area root: goal positions are offsets from its position.</summary>
    public GameObject area;
    /// <summary>Goal for symbol O (presumably tagged "symbol_O_Goal" — confirm in the scene).</summary>
    public GameObject symbolOGoal;
    /// <summary>Goal for symbol X (presumably tagged "symbol_X_Goal" — confirm in the scene).</summary>
    public GameObject symbolXGoal;
    public GameObject symbolO;
    public GameObject symbolX;
    /// <summary>When true, CollectObservations adds the normalized episode progress as a vector observation.</summary>
    public bool useVectorObs;

    protected Rigidbody m_AgentRb;
    // Ground material captured at initialization, restored after each flash.
    protected Material m_GroundMaterial;
    protected Renderer m_GroundRenderer;
    protected HallwaySettings m_HallwaySettings;
    // 0 = symbol O is shown this episode, 1 = symbol X is shown.
    protected int m_Selection;
    // Stats recording is currently disabled; the calls that would use this are commented out.
    StatsRecorder m_statsRecorder;

    /// <summary>Caches scene settings, the agent's rigidbody, and the ground renderer/material.</summary>
    public override void Initialize()
    {
        m_HallwaySettings = FindObjectOfType<HallwaySettings>();
        m_AgentRb = GetComponent<Rigidbody>();
        m_GroundRenderer = ground.GetComponent<Renderer>();
        m_GroundMaterial = m_GroundRenderer.material;
        // m_statsRecorder = Academy.Instance.StatsRecorder;
    }

    /// <summary>
    /// Adds StepCount / MaxStep as a single observation when <see cref="useVectorObs"/> is set.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        if (!useVectorObs)
        {
            return;
        }
        // MaxStep == 0 means "no step limit"; the ratio would then be NaN/Infinity,
        // so report 0 to keep the observation finite and its size unchanged.
        var progress = MaxStep > 0 ? StepCount / (float)MaxStep : 0f;
        sensor.AddObservation(progress);
    }

    /// <summary>
    /// Coroutine: shows <paramref name="mat"/> on the ground for <paramref name="time"/>
    /// seconds, then restores the material cached in Initialize.
    /// </summary>
    protected IEnumerator GoalScoredSwapGroundMaterial(Material mat, float time)
    {
        m_GroundRenderer.material = mat;
        yield return new WaitForSeconds(time);
        m_GroundRenderer.material = m_GroundMaterial;
    }

    /// <summary>
    /// Applies discrete action act[0]: 1/2 push forward/backward with a velocity change
    /// scaled by agentRunSpeed, 3/4 rotate about transform.up; any other value does nothing.
    /// </summary>
    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;
        switch (act[0])
        {
            case k_ActionForward:
                dirToGo = transform.forward;
                break;
            case k_ActionBackward:
                dirToGo = -transform.forward;
                break;
            case k_ActionTurnRight:
                rotateDir = transform.up;
                break;
            case k_ActionTurnLeft:
                rotateDir = -transform.up;
                break;
        }
        transform.Rotate(rotateDir, Time.deltaTime * k_TurnSpeed);
        m_AgentRb.AddForce(dirToGo * m_HallwaySettings.agentRunSpeed, ForceMode.VelocityChange);
    }

    /// <summary>Applies the per-step time penalty, then moves the agent.</summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Penalty of -1/MaxStep per action. Skipped when MaxStep is 0 (no step limit),
        // where the division would produce -Infinity.
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }
        MoveAgent(actionBuffers.DiscreteActions);
    }

    /// <summary>
    /// On contact with either goal: sets the reward (correct or wrong), flashes the
    /// ground, and ends the episode. Collisions with anything else are ignored.
    /// </summary>
    void OnCollisionEnter(Collision col)
    {
        var hitOGoal = col.gameObject.CompareTag("symbol_O_Goal");
        var hitXGoal = col.gameObject.CompareTag("symbol_X_Goal");
        if (!hitOGoal && !hitXGoal)
        {
            return;
        }

        var correct = (m_Selection == 0 && hitOGoal) || (m_Selection == 1 && hitXGoal);
        if (correct)
        {
            SetReward(k_CorrectGoalReward);
            StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.goalScoredMaterial, k_GroundFlashSeconds));
            // m_statsRecorder.Add("Goal/Correct", 1, StatAggregationMethod.Sum);
        }
        else
        {
            SetReward(k_WrongGoalReward);
            StartCoroutine(GoalScoredSwapGroundMaterial(m_HallwaySettings.failMaterial, k_GroundFlashSeconds));
            // m_statsRecorder.Add("Goal/Wrong", 1, StatAggregationMethod.Sum);
        }
        EndEpisode();
    }

    /// <summary>
    /// Keyboard control: D = turn (3), W = forward (1), A = turn (4), S = backward (2),
    /// checked in that order; no key = no-op (0).
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = k_ActionNone;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = k_ActionTurnRight;
        }
        else if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = k_ActionForward;
        }
        else if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = k_ActionTurnLeft;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = k_ActionBackward;
        }
    }

    /// <summary>
    /// Picks which symbol to show, places the symbols, respawns the agent with a random
    /// pose, and randomly swaps which side each goal is on.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        m_Selection = Random.Range(0, 2);
        PlaceSymbols();
        ResetAgentPose();
        RandomizeGoalSides();
        // m_statsRecorder.Add("Goal/Correct", 0, StatAggregationMethod.Sum);
        // m_statsRecorder.Add("Goal/Wrong", 0, StatAggregationMethod.Sum);
    }

    // Places the symbol chosen by m_Selection at a random spot above the ground and parks the other far below.
    void PlaceSymbols()
    {
        var shown = m_Selection == 0 ? symbolO : symbolX;
        var hidden = m_Selection == 0 ? symbolX : symbolO;
        var origin = ground.transform.position;
        // Same lateral and depth jitter for either symbol, so position alone does not distinguish O from X.
        shown.transform.position = origin + new Vector3(
            Random.Range(-k_SpawnHalfWidthX, k_SpawnHalfWidthX),
            k_SymbolHeight,
            k_SymbolZOffset + Random.Range(-k_SpawnHalfDepthZ, k_SpawnHalfDepthZ));
        hidden.transform.position = origin + new Vector3(0f, k_HiddenSymbolHeight, k_SymbolZOffset);
    }

    // Respawns the agent at a random position/heading near the start of the ground and clears its motion.
    void ResetAgentPose()
    {
        transform.position = ground.transform.position + new Vector3(
            Random.Range(-k_SpawnHalfWidthX, k_SpawnHalfWidthX),
            k_AgentHeight,
            k_AgentZOffset + Random.Range(-k_SpawnHalfDepthZ, k_SpawnHalfDepthZ));
        transform.rotation = Quaternion.Euler(0f, Random.Range(0f, 360f), 0f);
        // Clear both linear and angular motion so nothing carries over from the previous episode.
        m_AgentRb.velocity = Vector3.zero;
        m_AgentRb.angularVelocity = Vector3.zero;
    }

    // With equal probability puts the O goal on the +x side and the X goal on the -x side, or vice versa.
    void RandomizeGoalSides()
    {
        var oGoalOnPositiveX = Random.Range(0, 2) == 0;
        var positiveSide = new Vector3(k_GoalSideX, k_GoalHeight, k_GoalZ) + area.transform.position;
        var negativeSide = new Vector3(-k_GoalSideX, k_GoalHeight, k_GoalZ) + area.transform.position;
        symbolOGoal.transform.position = oGoalOnPositiveX ? positiveSide : negativeSide;
        symbolXGoal.transform.position = oGoalOnPositiveX ? negativeSide : positiveSide;
    }
}