您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
152 行
5.4 KiB
152 行
5.4 KiB
using System.Collections.Generic;
|
|
using NUnit.Framework;
|
|
using UnityEngine;
|
|
using UnityEngine.TestTools;
|
|
using Unity.MLAgents.Sensors;
|
|
|
|
namespace Unity.MLAgents.Tests
|
|
{
|
|
public class DummySensor : ISensor
|
|
{
|
|
string m_Name = "DummySensor";
|
|
int[] m_Shape;
|
|
|
|
public DummySensor(int dim1)
|
|
{
|
|
m_Shape = new[] { dim1 };
|
|
}
|
|
|
|
public DummySensor(int dim1, int dim2)
|
|
{
|
|
m_Shape = new[] { dim1, dim2, };
|
|
}
|
|
|
|
public DummySensor(int dim1, int dim2, int dim3)
|
|
{
|
|
m_Shape = new[] { dim1, dim2, dim3};
|
|
}
|
|
|
|
public string GetName()
|
|
{
|
|
return m_Name;
|
|
}
|
|
|
|
public int[] GetObservationShape()
|
|
{
|
|
return m_Shape;
|
|
}
|
|
|
|
public byte[] GetCompressedObservation()
|
|
{
|
|
return null;
|
|
}
|
|
|
|
public int Write(ObservationWriter writer)
|
|
{
|
|
return this.ObservationSize();
|
|
}
|
|
|
|
public void Update() { }
|
|
public void Reset() { }
|
|
|
|
public SensorCompressionType GetCompressionType()
|
|
{
|
|
return SensorCompressionType.None;
|
|
}
|
|
|
|
public SensorType GetSensorType()
|
|
{
|
|
return SensorType.Observation;
|
|
}
|
|
}
|
|
|
|
public class SensorShapeValidatorTests
|
|
{
|
|
|
|
[Test]
|
|
public void TestShapesAgree()
|
|
{
|
|
var validator = new SensorShapeValidator();
|
|
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
|
|
validator.ValidateSensors(sensorList1);
|
|
|
|
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
|
|
validator.ValidateSensors(sensorList2);
|
|
}
|
|
|
|
[Test]
|
|
public void TestNumSensorMismatch()
|
|
{
|
|
var validator = new SensorShapeValidator();
|
|
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
|
|
validator.ValidateSensors(sensorList1);
|
|
|
|
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), };
|
|
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
|
|
validator.ValidateSensors(sensorList2);
|
|
|
|
// Add the sensors in the other order
|
|
validator = new SensorShapeValidator();
|
|
validator.ValidateSensors(sensorList2);
|
|
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
|
|
validator.ValidateSensors(sensorList1);
|
|
}
|
|
[Test]
|
|
public void TestDimensionMismatch()
|
|
{
|
|
var validator = new SensorShapeValidator();
|
|
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
|
|
validator.ValidateSensors(sensorList1);
|
|
|
|
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5) };
|
|
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
|
|
validator.ValidateSensors(sensorList2);
|
|
|
|
// Add the sensors in the other order
|
|
validator = new SensorShapeValidator();
|
|
validator.ValidateSensors(sensorList2);
|
|
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
|
|
validator.ValidateSensors(sensorList1);
|
|
}
|
|
|
|
[Test]
|
|
public void TestSizeMismatch()
|
|
{
|
|
var validator = new SensorShapeValidator();
|
|
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
|
|
validator.ValidateSensors(sensorList1);
|
|
|
|
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 7) };
|
|
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
|
|
validator.ValidateSensors(sensorList2);
|
|
|
|
// Add the sensors in the other order
|
|
validator = new SensorShapeValidator();
|
|
validator.ValidateSensors(sensorList2);
|
|
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
|
|
validator.ValidateSensors(sensorList1);
|
|
}
|
|
|
|
[Test]
|
|
public void TestEverythingMismatch()
|
|
{
|
|
var validator = new SensorShapeValidator();
|
|
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
|
|
validator.ValidateSensors(sensorList1);
|
|
|
|
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(9) };
|
|
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
|
|
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
|
|
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
|
|
validator.ValidateSensors(sensorList2);
|
|
|
|
// Add the sensors in the other order
|
|
validator = new SensorShapeValidator();
|
|
validator.ValidateSensors(sensorList2);
|
|
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
|
|
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
|
|
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
|
|
validator.ValidateSensors(sensorList1);
|
|
}
|
|
}
|
|
}
|