您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
78 行
2.6 KiB
78 行
2.6 KiB
using NUnit.Framework;
|
|
using Unity.MLAgents.Sensors;
|
|
|
|
|
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Tests for the <see cref="ObservationSpec"/> factory methods (Vector, VariableLength, Visual)
    /// and for the constructor's rejection of mismatched shape / dimension-property lengths.
    /// </summary>
    [TestFixture]
    public class ObservationSpecTests
    {
        /// <summary>
        /// Checks, in order: the spec's rank, the shape length, every shape entry, the number of
        /// dimension properties, every dimension property, and finally that the observation type
        /// is <see cref="ObservationType.Default"/>. The first failing check stops the test.
        /// </summary>
        /// <param name="spec">The spec under test.</param>
        /// <param name="expectedShape">Expected shape; its length is also the expected rank.</param>
        /// <param name="expectedProps">Expected per-dimension properties.</param>
        static void AssertSpecMatches(
            ObservationSpec spec,
            int[] expectedShape,
            DimensionProperty[] expectedProps)
        {
            // Each factory-built spec is expected to report a rank equal to its shape length.
            Assert.AreEqual(expectedShape.Length, spec.Rank);

            var actualShape = spec.Shape;
            Assert.AreEqual(expectedShape.Length, actualShape.Length);
            for (var dim = 0; dim < expectedShape.Length; dim++)
            {
                Assert.AreEqual(expectedShape[dim], actualShape[dim]);
            }

            var actualProps = spec.DimensionProperties;
            Assert.AreEqual(expectedProps.Length, actualProps.Length);
            for (var dim = 0; dim < expectedProps.Length; dim++)
            {
                Assert.AreEqual(expectedProps[dim], actualProps[dim]);
            }

            Assert.AreEqual(ObservationType.Default, spec.ObservationType);
        }

        /// <summary>
        /// <c>ObservationSpec.Vector(5)</c> yields a rank-1 spec of shape [5] with no special
        /// dimension property.
        /// </summary>
        [Test]
        public void TestVectorObsSpec()
        {
            AssertSpecMatches(
                ObservationSpec.Vector(5),
                new[] { 5 },
                new[] { DimensionProperty.None });
        }

        /// <summary>
        /// <c>ObservationSpec.VariableLength(5, 6)</c> yields a rank-2 spec of shape [5, 6] whose
        /// first dimension is marked variable-size.
        /// </summary>
        [Test]
        public void TestVariableLengthObsSpec()
        {
            AssertSpecMatches(
                ObservationSpec.VariableLength(5, 6),
                new[] { 5, 6 },
                new[] { DimensionProperty.VariableSize, DimensionProperty.None });
        }

        /// <summary>
        /// <c>ObservationSpec.Visual(5, 6, 7)</c> yields a rank-3 spec of shape [5, 6, 7] whose
        /// first two dimensions are marked translationally equivariant.
        /// </summary>
        [Test]
        public void TestVisualObsSpec()
        {
            AssertSpecMatches(
                ObservationSpec.Visual(5, 6, 7),
                new[] { 5, 6, 7 },
                new[]
                {
                    DimensionProperty.TranslationalEquivariance,
                    DimensionProperty.TranslationalEquivariance,
                    DimensionProperty.None,
                });
        }

        /// <summary>
        /// Constructing a spec from a two-entry shape but only one dimension property must throw
        /// <see cref="UnityAgentsException"/>.
        /// </summary>
        [Test]
        public void TestMismatchShapeDimensionPropThrows()
        {
            var twoDimShape = new InplaceArray<int>(1, 2);
            var oneDimProps = new InplaceArray<DimensionProperty>(DimensionProperty.TranslationalEquivariance);
            Assert.Throws<UnityAgentsException>(() => new ObservationSpec(twoDimShape, oneDimProps));
        }
    }
}
|