您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
146 行
5.3 KiB
146 行
5.3 KiB
using System.Collections.Generic;
|
|
using System.Text.RegularExpressions;
|
|
using NUnit.Framework;
|
|
using UnityEngine;
|
|
using UnityEngine.TestTools;
|
|
using Unity.MLAgents.Sensors;
|
|
|
|
namespace Unity.MLAgents.Tests
|
|
{
|
|
/// <summary>
/// Minimal <see cref="ISensor"/> stub used by the tests in this file. The only state that
/// varies between instances is the <see cref="ObservationSpec"/>, which is selected by
/// the constructor overload that was called.
/// </summary>
public class DummySensor : ISensor
{
    private string m_Name = "DummySensor";
    private ObservationSpec m_ObservationSpec;

    /// <summary>Creates a sensor whose spec comes from <c>ObservationSpec.Vector(dim1)</c>.</summary>
    public DummySensor(int dim1) =>
        m_ObservationSpec = ObservationSpec.Vector(dim1);

    /// <summary>Creates a sensor whose spec comes from <c>ObservationSpec.VariableSize(dim1, dim2)</c>.</summary>
    public DummySensor(int dim1, int dim2) =>
        m_ObservationSpec = ObservationSpec.VariableSize(dim1, dim2);

    /// <summary>Creates a sensor whose spec comes from <c>ObservationSpec.Visual(dim1, dim2, dim3)</c>.</summary>
    public DummySensor(int dim1, int dim2, int dim3) =>
        m_ObservationSpec = ObservationSpec.Visual(dim1, dim2, dim3);

    /// <summary>Returns the fixed sensor name, "DummySensor".</summary>
    public string GetName() => m_Name;

    /// <summary>Returns the spec that was built by the constructor.</summary>
    public ObservationSpec GetObservationSpec() => m_ObservationSpec;

    /// <summary>This stub has no compressed observation; always returns null.</summary>
    public byte[] GetCompressedObservation() => null;

    /// <summary>
    /// Writes nothing to <paramref name="writer"/>; only reports the value of
    /// <c>this.ObservationSize()</c>.
    /// </summary>
    // The explicit `this.` receiver is kept: ObservationSize() is not declared on this class,
    // so it is presumably an extension method on ISensor and needs an explicit receiver.
    public int Write(ObservationWriter writer) => this.ObservationSize();

    /// <summary>No-op.</summary>
    public void Update() { }

    /// <summary>No-op.</summary>
    public void Reset() { }

    /// <summary>Always reports <see cref="SensorCompressionType.None"/>.</summary>
    public SensorCompressionType GetCompressionType() => SensorCompressionType.None;
}
|
|
|
|
/// <summary>
/// Tests for <c>SensorShapeValidator</c>. Each test feeds one sensor list to a fresh validator,
/// then a second list, and registers with <see cref="LogAssert"/> the assert-level messages
/// the second call is expected to log. The mismatch tests repeat the exercise with the two
/// lists swapped.
/// </summary>
public class SensorShapeValidatorTests
{
    // Pattern of the assert message expected when corresponding sensors disagree on shape.
    const string k_ShapeMismatchPattern = "Sensor shapes must match.*";

    /// <summary>
    /// Builds the three-sensor list used as the baseline in every test: one sensor from each
    /// DummySensor constructor overload, with arguments (1), (2, 3) and (4, 5, 6).
    /// </summary>
    static List<ISensor> CreateBaselineSensors()
    {
        return new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
            new DummySensor(4, 5, 6),
        };
    }

    /// <summary>Registers one expected assert-level log matching the shape-mismatch pattern.</summary>
    static void ExpectShapeMismatch()
    {
        LogAssert.Expect(LogType.Assert, new Regex(k_ShapeMismatchPattern));
    }

    /// <summary>Two identically-built lists are validated with no expected log messages.</summary>
    [Test]
    public void TestShapesAgree()
    {
        var shapeValidator = new SensorShapeValidator();

        var firstBatch = CreateBaselineSensors();
        shapeValidator.ValidateSensors(firstBatch);

        var secondBatch = CreateBaselineSensors();
        shapeValidator.ValidateSensors(secondBatch);
    }

    /// <summary>A three-sensor list versus a two-sensor list: expects the count message.</summary>
    [Test]
    public void TestNumSensorMismatch()
    {
        var shapeValidator = new SensorShapeValidator();

        var threeSensors = CreateBaselineSensors();
        shapeValidator.ValidateSensors(threeSensors);

        var twoSensors = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
        };
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
        shapeValidator.ValidateSensors(twoSensors);

        // Same pair, fed to a fresh validator in the opposite order.
        shapeValidator = new SensorShapeValidator();
        shapeValidator.ValidateSensors(twoSensors);
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
        shapeValidator.ValidateSensors(threeSensors);
    }

    /// <summary>
    /// The last sensor is built with two constructor arguments in one list and three in the
    /// other: expects the shape message.
    /// </summary>
    [Test]
    public void TestDimensionMismatch()
    {
        var shapeValidator = new SensorShapeValidator();

        var baseline = CreateBaselineSensors();
        shapeValidator.ValidateSensors(baseline);

        var fewerDims = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
            new DummySensor(4, 5),
        };
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(fewerDims);

        // Same pair, fed to a fresh validator in the opposite order.
        shapeValidator = new SensorShapeValidator();
        shapeValidator.ValidateSensors(fewerDims);
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(baseline);
    }

    /// <summary>
    /// The last sensor differs only in its third constructor argument (6 versus 7):
    /// expects the shape message.
    /// </summary>
    [Test]
    public void TestSizeMismatch()
    {
        var shapeValidator = new SensorShapeValidator();

        var baseline = CreateBaselineSensors();
        shapeValidator.ValidateSensors(baseline);

        var differentSize = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
            new DummySensor(4, 5, 7),
        };
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(differentSize);

        // Same pair, fed to a fresh validator in the opposite order.
        shapeValidator = new SensorShapeValidator();
        shapeValidator.ValidateSensors(differentSize);
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(baseline);
    }

    /// <summary>
    /// The lists differ both in length and in how their second sensor is constructed:
    /// expects the count message followed by the shape message.
    /// </summary>
    [Test]
    public void TestEverythingMismatch()
    {
        var shapeValidator = new SensorShapeValidator();

        var baseline = CreateBaselineSensors();
        shapeValidator.ValidateSensors(baseline);

        var shortAndDifferent = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(9),
        };
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(shortAndDifferent);

        // Same pair, fed to a fresh validator in the opposite order.
        shapeValidator = new SensorShapeValidator();
        shapeValidator.ValidateSensors(shortAndDifferent);
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(baseline);
    }
}
|
|
}
|