using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using UnityEngine;

namespace Unity.MLAgents.Tests.Policies
{
    /// <summary>
    /// Tests that the action buffers handed to <c>Heuristic</c> callbacks are zeroed
    /// before each call when the behavior type is <see cref="BehaviorType.HeuristicOnly"/>.
    /// </summary>
    [TestFixture]
    public class HeuristicPolicyTest
    {
        /// <summary>
        /// Disposes any already-initialized <see cref="Academy"/> so each test starts
        /// from a fresh instance.
        /// </summary>
        [SetUp]
        public void SetUp()
        {
            if (Academy.IsInitialized)
            {
                Academy.Instance.Dispose();
            }
        }

        /// <summary>
        /// Assert that the action buffers are initialized to zero, and then set them to non-zero values.
        /// </summary>
        /// <param name="actionsOut">The action buffers to verify and then overwrite.</param>
        static void CheckAndSetBuffer(in ActionBuffers actionsOut)
        {
            var continuousActions = actionsOut.ContinuousActions;
            for (var continuousIndex = 0; continuousIndex < continuousActions.Length; continuousIndex++)
            {
                // NUnit's signature is (expected, actual); keeping that order makes
                // failure messages report the right values.
                Assert.AreEqual(0.0f, continuousActions[continuousIndex]);
                continuousActions[continuousIndex] = 1.0f;
            }

            var discreteActions = actionsOut.DiscreteActions;
            for (var discreteIndex = 0; discreteIndex < discreteActions.Length; discreteIndex++)
            {
                Assert.AreEqual(0, discreteActions[discreteIndex]);
                discreteActions[discreteIndex] = 1;
            }
        }

        /// <summary>
        /// Agent whose heuristic checks that its buffers arrive zeroed, overwrites them,
        /// and counts how many times it was invoked.
        /// </summary>
        class ActionClearedAgent : Agent
        {
            public int HeuristicCalls = 0;

            public override void Heuristic(in ActionBuffers actionsOut)
            {
                CheckAndSetBuffer(actionsOut);
                HeuristicCalls++;
            }
        }

        /// <summary>
        /// Actuator whose heuristic performs the same check-then-overwrite as
        /// <see cref="ActionClearedAgent"/> and counts its invocations. All other
        /// <see cref="IActuator"/> members are intentionally no-ops.
        /// </summary>
        class ActionClearedActuator : IActuator
        {
            public int HeuristicCalls = 0;

            public ActionClearedActuator(ActionSpec actionSpec)
            {
                ActionSpec = actionSpec;
                Name = GetType().Name;
            }

            public void OnActionReceived(ActionBuffers actionBuffers)
            {
            }

            public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
            {
            }

            public void Heuristic(in ActionBuffers actionBuffersOut)
            {
                CheckAndSetBuffer(actionBuffersOut);
                HeuristicCalls++;
            }

            public ActionSpec ActionSpec { get; }

            public string Name { get; }

            public void ResetData()
            {
            }
        }

        /// <summary>
        /// Component that creates a single <see cref="ActionClearedActuator"/> and keeps
        /// a reference to it so the test can read its call counter.
        /// </summary>
        class ActionClearedActuatorComponent : ActuatorComponent
        {
            public ActionClearedActuator ActionClearedActuator;

            public ActionClearedActuatorComponent()
            {
                ActionSpec = new ActionSpec(2, new[] { 3, 3 });
            }

            public override IActuator[] CreateActuators()
            {
                ActionClearedActuator = new ActionClearedActuator(ActionSpec);
                return new IActuator[] { ActionClearedActuator };
            }

            public override ActionSpec ActionSpec { get; }
        }

        /// <summary>
        /// Requests a decision and steps the environment <c>k_NumSteps</c> times, then
        /// checks that both the agent's and the actuator's heuristic ran once per step.
        /// Each heuristic call asserts its buffers are zero before writing ones into
        /// them, so a buffer still holding values from a previous call fails the test.
        /// </summary>
        [Test]
        public void TestActionsCleared()
        {
            var gameObj = new GameObject();
            var agent = gameObj.AddComponent<ActionClearedAgent>();
            var behaviorParameters = agent.GetComponent<BehaviorParameters>();
            behaviorParameters.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 4 });
            behaviorParameters.BrainParameters.VectorObservationSize = 0;
            behaviorParameters.BehaviorType = BehaviorType.HeuristicOnly;

            // The actuator component must be attached before LazyInitialize() is called.
            var actuatorComponent = gameObj.AddComponent<ActionClearedActuatorComponent>();
            agent.LazyInitialize();

            const int k_NumSteps = 5;
            for (var i = 0; i < k_NumSteps; i++)
            {
                agent.RequestDecision();
                Academy.Instance.EnvironmentStep();
            }

            Assert.AreEqual(k_NumSteps, agent.HeuristicCalls);
            Assert.AreEqual(k_NumSteps, actuatorComponent.ActionClearedActuator.HeuristicCalls);
        }
    }
}