using System; using System.Collections.Generic; using System.Linq; using NUnit.Framework; using Unity.MLAgents.Actuators; using Assert = UnityEngine.Assertions.Assert; namespace Unity.MLAgents.Tests.Actuators { [TestFixture] public class VectorActuatorTests { class TestActionReceiver : IActionReceiver, IHeuristicProvider { public ActionBuffers LastActionBuffers; public int Branch; public IList Mask; public ActionSpec ActionSpec { get; } public bool HeuristicCalled; public void OnActionReceived(ActionBuffers actionBuffers) { LastActionBuffers = actionBuffers; } public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { foreach (var actionIndex in Mask) { actionMask.SetActionEnabled(Branch, actionIndex, false); } } public void Heuristic(in ActionBuffers actionBuffersOut) { HeuristicCalled = true; } } [Test] public void TestConstruct() { var ar = new TestActionReceiver(); var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); Assert.IsTrue(va.ActionSpec.NumDiscreteActions == 3); Assert.IsTrue(va.ActionSpec.SumOfDiscreteBranchSizes == 6); Assert.IsTrue(va.ActionSpec.NumContinuousActions == 0); var va1 = new VectorActuator(ar, ActionSpec.MakeContinuous(4), "name"); Assert.IsTrue(va1.ActionSpec.NumContinuousActions == 4); Assert.IsTrue(va1.ActionSpec.SumOfDiscreteBranchSizes == 0); Assert.AreEqual(va1.Name, "name-Continuous"); } [Test] public void TestOnActionReceived() { var ar = new TestActionReceiver(); var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); var discreteActions = new[] { 0, 1, 1 }; var ab = new ActionBuffers(ActionSegment.Empty, new ActionSegment(discreteActions, 0, 3)); va.OnActionReceived(ab); Assert.AreEqual(ar.LastActionBuffers, ab); va.ResetData(); Assert.AreEqual(va.ActionBuffers.ContinuousActions, ActionSegment.Empty); Assert.AreEqual(va.ActionBuffers.DiscreteActions, ActionSegment.Empty); } [Test] public void TestResetData() { var ar = new TestActionReceiver(); var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); var discreteActions = new[] { 0, 1, 1 }; var ab = new ActionBuffers(ActionSegment.Empty, new ActionSegment(discreteActions, 0, 3)); va.OnActionReceived(ab); } [Test] public void TestWriteDiscreteActionMask() { var ar = new TestActionReceiver(); var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); var bdam = new ActuatorDiscreteActionMask(new[] { va }, 6, 3); var groundTruthMask = new[] { false, true, false, false, true, true }; ar.Branch = 1; ar.Mask = new[] { 0 }; va.WriteDiscreteActionMask(bdam); ar.Branch = 2; ar.Mask = new[] { 1, 2 }; va.WriteDiscreteActionMask(bdam); Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask())); } [Test] public void TestHeuristic() { var ar = new TestActionReceiver(); var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); va.Heuristic(new ActionBuffers(Array.Empty(), va.ActionSpec.BranchSizes)); Assert.IsTrue(ar.HeuristicCalled); } } }