using NUnit.Framework; using Unity.MLAgents.Sensors; namespace Unity.MLAgents.Tests { public class StackingSensorTests { [Test] public void TestCtor() { ISensor wrapped = new VectorSensor(4); ISensor sensor = new StackingSensor(wrapped, 4); Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName()); Assert.AreEqual(sensor.GetObservationShape(), new[] {16}); } [Test] public void TestStacking() { VectorSensor wrapped = new VectorSensor(2); ISensor sensor = new StackingSensor(wrapped, 3); wrapped.AddObservation(new[] {1f, 2f}); SensorTestHelper.CompareObservation(sensor, new[] {0f, 0f, 0f, 0f, 1f, 2f}); sensor.Update(); wrapped.AddObservation(new[] {3f, 4f}); SensorTestHelper.CompareObservation(sensor, new[] {0f, 0f, 1f, 2f, 3f, 4f}); sensor.Update(); wrapped.AddObservation(new[] {5f, 6f}); SensorTestHelper.CompareObservation(sensor, new[] {1f, 2f, 3f, 4f, 5f, 6f}); sensor.Update(); wrapped.AddObservation(new[] {7f, 8f}); SensorTestHelper.CompareObservation(sensor, new[] {3f, 4f, 5f, 6f, 7f, 8f}); sensor.Update(); wrapped.AddObservation(new[] {9f, 10f}); SensorTestHelper.CompareObservation(sensor, new[] {5f, 6f, 7f, 8f, 9f, 10f}); // Check that if we don't call Update(), the same observations are produced SensorTestHelper.CompareObservation(sensor, new[] {5f, 6f, 7f, 8f, 9f, 10f}); } [Test] public void TestStackingReset() { VectorSensor wrapped = new VectorSensor(2); ISensor sensor = new StackingSensor(wrapped, 3); wrapped.AddObservation(new[] {1f, 2f}); SensorTestHelper.CompareObservation(sensor, new[] {0f, 0f, 0f, 0f, 1f, 2f}); sensor.Update(); wrapped.AddObservation(new[] {3f, 4f}); SensorTestHelper.CompareObservation(sensor, new[] {0f, 0f, 1f, 2f, 3f, 4f}); sensor.Reset(); wrapped.AddObservation(new[] {5f, 6f}); SensorTestHelper.CompareObservation(sensor, new[] {0f, 0f, 0f, 0f, 5f, 6f}); } } }