using NUnit.Framework;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Tests for the <see cref="ObservationSpec"/> factory methods and constructor:
    /// the rank, shape, per-dimension properties and observation type each produces.
    /// </summary>
    [TestFixture]
    public class ObservationSpecTests
    {
        /// <summary>
        /// <c>ObservationSpec.Vector(5)</c> should describe a rank-1 observation of
        /// shape [5], with no special dimension property and the default observation type.
        /// </summary>
        [Test]
        public void TestVectorObsSpec()
        {
            var obsSpec = ObservationSpec.Vector(5);
            Assert.AreEqual(1, obsSpec.Rank);

            var shape = obsSpec.Shape;
            Assert.AreEqual(1, shape.Length);
            Assert.AreEqual(5, shape[0]);

            var dimensionProps = obsSpec.DimensionProperties;
            Assert.AreEqual(1, dimensionProps.Length);
            Assert.AreEqual(DimensionProperty.None, dimensionProps[0]);

            Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
        }

        /// <summary>
        /// <c>ObservationSpec.VariableLength(5, 6)</c> should describe a rank-2 observation
        /// of shape [5, 6]. Only the first dimension is marked as variable-size.
        /// </summary>
        [Test]
        public void TestVariableLengthObsSpec()
        {
            var obsSpec = ObservationSpec.VariableLength(5, 6);
            Assert.AreEqual(2, obsSpec.Rank);

            var shape = obsSpec.Shape;
            Assert.AreEqual(2, shape.Length);
            Assert.AreEqual(5, shape[0]);
            Assert.AreEqual(6, shape[1]);

            var dimensionProps = obsSpec.DimensionProperties;
            Assert.AreEqual(2, dimensionProps.Length);
            Assert.AreEqual(DimensionProperty.VariableSize, dimensionProps[0]);
            Assert.AreEqual(DimensionProperty.None, dimensionProps[1]);

            Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
        }

        /// <summary>
        /// <c>ObservationSpec.Visual(5, 6, 7)</c> should describe a rank-3 observation of
        /// shape [5, 6, 7]. The first two dimensions are translationally equivariant; the last has no
        /// special property.
        /// </summary>
        [Test]
        public void TestVisualObsSpec()
        {
            var obsSpec = ObservationSpec.Visual(5, 6, 7);
            Assert.AreEqual(3, obsSpec.Rank);

            var shape = obsSpec.Shape;
            Assert.AreEqual(3, shape.Length);
            Assert.AreEqual(5, shape[0]);
            Assert.AreEqual(6, shape[1]);
            Assert.AreEqual(7, shape[2]);

            var dimensionProps = obsSpec.DimensionProperties;
            Assert.AreEqual(3, dimensionProps.Length);
            Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[0]);
            Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[1]);
            Assert.AreEqual(DimensionProperty.None, dimensionProps[2]);

            Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
        }

        /// <summary>
        /// Constructing an <see cref="ObservationSpec"/> whose shape (2 entries) and
        /// dimension properties (1 entry) have different lengths must throw.
        /// </summary>
        [Test]
        public void TestMismatchShapeDimensionPropThrows()
        {
            // InplaceArray is generic; the element types must be spelled out explicitly.
            var shape = new InplaceArray<int>(1, 2);
            var dimProps = new InplaceArray<DimensionProperty>(DimensionProperty.TranslationalEquivariance);

            // NUnit's single-delegate Assert.Throws overload is generic: the expected
            // exception type must be supplied as the type argument.
            // NOTE(review): assumes the constructor reports the mismatch via
            // UnityAgentsException — confirm against ObservationSpec's constructor.
            Assert.Throws<UnityAgentsException>(() =>
            {
                new ObservationSpec(shape, dimProps);
            });
        }
    }
}