using UnityEngine;
using MLAgents;

/// <summary>
/// Agent for the "Basic" environment. The agent moves left or right along a
/// discrete 1-D track (positions m_MinPosition..m_MaxPosition).
/// Reaching the small goal gives +0.1 and reaching the large goal gives +1;
/// either one ends the episode. Every step costs -0.01.
/// </summary>
public class BasicAgent : Agent
{
    // Seconds to wait between decision requests when no external communicator
    // is connected (presumably inference or heuristic play).
    [Header("Specific to Basic")]
    public float timeBetweenDecisionsAtInference;
    float m_TimeSinceDecision;
    int m_Position;
    int m_SmallGoalPosition;
    int m_LargeGoalPosition;
    public GameObject largeGoal;
    public GameObject smallGoal;
    int m_MinPosition;
    int m_MaxPosition;

    // Subtracted from a track position to get world x, so track 0..20 maps to x -10..10.
    const float k_TrackToWorldOffset = 10f;

    /// <summary>
    /// Maps an integer track position to the world-space point used for the
    /// agent and both goal objects (y and z are always 0).
    /// </summary>
    /// <param name="trackPosition">Position on the discrete track.</param>
    /// <returns>World-space position for that track cell.</returns>
    static Vector3 TrackToWorld(int trackPosition)
    {
        return new Vector3(trackPosition - k_TrackToWorldOffset, 0f, 0f);
    }

    /// <summary>No one-time initialization is needed; all state is set in <see cref="AgentReset"/>.</summary>
    public override void InitializeAgent()
    {
    }

    /// <summary>
    /// Observes the agent's current track position as a one-hot vector of length 20.
    /// </summary>
    /// <param name="sensor">Sensor that receives the observation.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // NOTE(review): m_Position is clamped to 0..m_MaxPosition (20), but the
        // one-hot has only 20 slots (indices 0..19). Starting at 10 and moving
        // one cell per step, the agent presumably hits a goal (7 or 17) before
        // reaching 20. How an out-of-range index is encoded depends on
        // VectorSensor — confirm. The size is kept at 20 because changing it
        // changes the observation size that behaviors/models expect.
        sensor.AddOneHotObservation(m_Position, 20);
    }

    /// <summary>
    /// Applies one discrete action. Then moves the agent, applies the step
    /// penalty, and ends the episode if a goal is reached.
    /// </summary>
    /// <param name="vectorAction">
    /// vectorAction[0]: 1 = move left (-1), 2 = move right (+1), anything else = stay.
    /// </param>
    public override void AgentAction(float[] vectorAction)
    {
        var movement = (int)vectorAction[0];

        var direction = 0;
        switch (movement)
        {
            case 1:
                direction = -1;
                break;
            case 2:
                direction = 1;
                break;
        }

        // Move, then clamp to the track bounds.
        m_Position += direction;
        if (m_Position < m_MinPosition)
        {
            m_Position = m_MinPosition;
        }
        if (m_Position > m_MaxPosition)
        {
            m_Position = m_MaxPosition;
        }

        gameObject.transform.position = TrackToWorld(m_Position);

        // Per-step penalty so that longer episodes score lower.
        AddReward(-0.01f);

        if (m_Position == m_SmallGoalPosition)
        {
            AddReward(0.1f);
            Done();
        }

        if (m_Position == m_LargeGoalPosition)
        {
            AddReward(1f);
            Done();
        }
    }

    /// <summary>
    /// Restores the track layout for a new episode: agent at 10, bounds 0..20,
    /// small goal at 7, large goal at 17. Places the agent and both goal objects to match.
    /// </summary>
    public override void AgentReset()
    {
        m_Position = 10;
        m_MinPosition = 0;
        m_MaxPosition = 20;
        m_SmallGoalPosition = 7;
        m_LargeGoalPosition = 17;

        // Keep the agent's transform in sync with its logical position.
        // Otherwise it stays drawn at the goal it just reached until the next action.
        gameObject.transform.position = TrackToWorld(m_Position);
        smallGoal.transform.position = TrackToWorld(m_SmallGoalPosition);
        largeGoal.transform.position = TrackToWorld(m_LargeGoalPosition);
    }

    /// <summary>
    /// Keyboard control: D moves right (action 2), A moves left (action 1).
    /// With no key held, the agent stays (action 0). D takes priority if both are held.
    /// </summary>
    /// <returns>A single-element action array.</returns>
    public override float[] Heuristic()
    {
        if (Input.GetKey(KeyCode.D))
        {
            return new float[] { 2 };
        }
        if (Input.GetKey(KeyCode.A))
        {
            return new float[] { 1 };
        }
        return new float[] { 0 };
    }

    /// <summary>Drives decision requests once per physics step.</summary>
    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Requests decisions. When the Academy communicator is on (presumably
    /// an external trainer is connected), a decision is requested every physics
    /// step. Otherwise requests are spaced by <see cref="timeBetweenDecisionsAtInference"/> seconds.
    /// </summary>
    void WaitTimeInference()
    {
        // Fixed: this condition was previously inverted, which throttled training
        // and ran inference at full physics rate.
        if (Academy.Instance.IsCommunicatorOn)
        {
            RequestDecision();
        }
        else
        {
            if (m_TimeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                m_TimeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                m_TimeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}