using UnityEngine;
using UnityEngine.SceneManagement;
using Unity.MLAgents;
using UnityEngine.Serialization;
/// <summary>
/// An example of how to use ML-Agents without inheriting from the Agent class.
/// Observations are generated by the attached SensorComponent, and the actions
/// are retrieved from the Agent.
/// </summary>
public class BasicController : MonoBehaviour
{
    /// <summary>
    /// Minimum amount of accumulated <c>Time.fixedDeltaTime</c> between two
    /// decisions while the Academy's communicator is off.
    /// </summary>
    public float timeBetweenDecisionsAtInference;

    // Fixed-update time accumulated since the last action was applied.
    // Only advanced while the communicator is off.
    float m_TimeSinceDecision;

    /// <summary>
    /// Current cell of the controlled object on the one-dimensional track,
    /// kept within [k_MinPosition, k_MaxPosition] by <see cref="ApplyAction"/>.
    /// </summary>
    [FormerlySerializedAs("m_Position")]
    [HideInInspector]
    public int position;

    const int k_SmallGoalPosition = 7;
    const int k_LargeGoalPosition = 17;

    /// <summary>Object placed at the large-goal cell in <see cref="OnEnable"/>.</summary>
    public GameObject largeGoal;

    /// <summary>Object placed at the small-goal cell in <see cref="OnEnable"/>.</summary>
    public GameObject smallGoal;

    const int k_MinPosition = 0;
    const int k_MaxPosition = 20;

    /// <summary>Number of steps between the lowest and the highest cell.</summary>
    public const int k_Extents = k_MaxPosition - k_MinPosition;

    // Cell the controlled object starts each episode on.
    const int k_StartPosition = 10;

    // Subtracted from a cell index to obtain its world-space x coordinate.
    const float k_WorldOffset = 10f;

    // Reward added on every applied action.
    const float k_StepPenalty = -0.01f;

    // Rewards added when the corresponding goal cell is reached.
    const float k_SmallGoalReward = 0.1f;
    const float k_LargeGoalReward = 1f;

    // The Agent component on this GameObject. Set to null by ResetAgent()
    // so that nothing touches it while the scene reload is pending.
    Agent m_Agent;

    /// <summary>
    /// Looks up the attached Agent and moves the controlled object and both
    /// goals to their starting locations.
    /// </summary>
    public void OnEnable()
    {
        m_Agent = GetComponent<Agent>();
        position = k_StartPosition;
        transform.position = CellToWorld(position);
        smallGoal.transform.position = CellToWorld(k_SmallGoalPosition);
        largeGoal.transform.position = CellToWorld(k_LargeGoalPosition);
    }

    /// <summary>
    /// Controls the movement of the GameObject based on the actions received.
    /// </summary>
    /// <param name="vectorAction">
    /// Action vector; only the first element is read. After truncation to an
    /// integer, 1 moves one cell down, 2 moves one cell up, and any other
    /// value leaves the position unchanged.
    /// </param>
    /// <remarks>
    /// Does nothing when the Agent has been released (a reset is pending) or
    /// when <paramref name="vectorAction"/> is null or empty.
    /// </remarks>
    public void ApplyAction(float[] vectorAction)
    {
        if (m_Agent == null || vectorAction == null || vectorAction.Length == 0)
        {
            return;
        }

        var movement = (int)vectorAction[0];
        var direction = 0;
        switch (movement)
        {
            case 1:
                direction = -1;
                break;
            case 2:
                direction = 1;
                break;
        }

        position = Mathf.Clamp(position + direction, k_MinPosition, k_MaxPosition);
        transform.position = CellToWorld(position);

        m_Agent.AddReward(k_StepPenalty);

        // The two goal cells are distinct, so at most one branch can fire.
        if (position == k_SmallGoalPosition)
        {
            FinishEpisode(k_SmallGoalReward);
        }
        else if (position == k_LargeGoalPosition)
        {
            FinishEpisode(k_LargeGoalReward);
        }
    }

    /// <summary>
    /// Reloads the active scene and releases the Agent reference.
    /// </summary>
    public void ResetAgent()
    {
        // This is a very inefficient way to reset the scene. Used here for testing.
        SceneManager.LoadScene(SceneManager.GetActiveScene().name);

        // LoadScene only takes effect at the next Update. The Agent is set to
        // null to avoid using it before the reload happens.
        m_Agent = null;
    }

    /// <summary>
    /// Drives the decision loop once per physics step.
    /// </summary>
    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    // Applies the Agent's latest action and requests a new decision: on every
    // call while the communicator is on, otherwise only once at least
    // timeBetweenDecisionsAtInference has accumulated.
    void WaitTimeInference()
    {
        if (m_Agent == null)
        {
            return;
        }

        if (Academy.Instance.IsCommunicatorOn)
        {
            ApplyActionAndRequestDecision();
            return;
        }

        if (m_TimeSinceDecision < timeBetweenDecisionsAtInference)
        {
            m_TimeSinceDecision += Time.fixedDeltaTime;
            return;
        }

        ApplyActionAndRequestDecision();
        m_TimeSinceDecision = 0f;
    }

    // Applies the previous step's action, then asks for the next decision
    // unless applying the action released the Agent (goal reached).
    void ApplyActionAndRequestDecision()
    {
        ApplyAction(m_Agent.GetAction());

        // Explicit comparison rather than "?." so that Unity's overloaded
        // null check for UnityEngine.Object is used.
        if (m_Agent != null)
        {
            m_Agent.RequestDecision();
        }
    }

    // Adds the goal reward, ends the Agent's episode and resets the scene.
    void FinishEpisode(float reward)
    {
        m_Agent.AddReward(reward);
        m_Agent.EndEpisode();
        ResetAgent();
    }

    // Maps a cell index on the track to its world-space location.
    static Vector3 CellToWorld(int cell)
    {
        return new Vector3(cell - k_WorldOffset, 0f, 0f);
    }
}