Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可选择25个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符。
 
 
 
 
 

146 行
5.3 KiB

using System.Collections.Generic;
using System.Text.RegularExpressions;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{
/// <summary>
/// Minimal <see cref="ISensor"/> implementation used as a test double.
/// The number of constructor arguments selects which <see cref="ObservationSpec"/>
/// factory builds the spec reported by <see cref="GetObservationSpec"/>.
/// </summary>
public class DummySensor : ISensor
{
    // Fixed name reported by every instance.
    const string k_SensorName = "DummySensor";

    // Spec chosen once by the constructor and returned unchanged afterwards.
    readonly ObservationSpec m_Spec;

    /// <summary>
    /// Creates a sensor whose spec comes from <c>ObservationSpec.Vector(dim1)</c>.
    /// </summary>
    /// <param name="dim1">Value forwarded to <c>ObservationSpec.Vector</c>.</param>
    public DummySensor(int dim1)
    {
        m_Spec = ObservationSpec.Vector(dim1);
    }

    /// <summary>
    /// Creates a sensor whose spec comes from <c>ObservationSpec.VariableLength(dim1, dim2)</c>.
    /// </summary>
    /// <param name="dim1">First value forwarded to <c>ObservationSpec.VariableLength</c>.</param>
    /// <param name="dim2">Second value forwarded to <c>ObservationSpec.VariableLength</c>.</param>
    public DummySensor(int dim1, int dim2)
    {
        m_Spec = ObservationSpec.VariableLength(dim1, dim2);
    }

    /// <summary>
    /// Creates a sensor whose spec comes from <c>ObservationSpec.Visual(dim1, dim2, dim3)</c>.
    /// </summary>
    /// <param name="dim1">First value forwarded to <c>ObservationSpec.Visual</c>.</param>
    /// <param name="dim2">Second value forwarded to <c>ObservationSpec.Visual</c>.</param>
    /// <param name="dim3">Third value forwarded to <c>ObservationSpec.Visual</c>.</param>
    public DummySensor(int dim1, int dim2, int dim3)
    {
        m_Spec = ObservationSpec.Visual(dim1, dim2, dim3);
    }

    /// <summary>Returns the constant sensor name, "DummySensor".</summary>
    public string GetName() => k_SensorName;

    /// <summary>Returns the spec built by the constructor.</summary>
    public ObservationSpec GetObservationSpec() => m_Spec;

    /// <summary>Always returns null; this sensor has no compressed data.</summary>
    public byte[] GetCompressedObservation() => null;

    /// <summary>
    /// Writes nothing to <paramref name="writer"/> and returns the result of
    /// <c>ObservationSize()</c> for this sensor.
    /// </summary>
    /// <param name="writer">Unused.</param>
    /// <returns>The value of <c>this.ObservationSize()</c>.</returns>
    public int Write(ObservationWriter writer) => this.ObservationSize();

    /// <summary>No-op.</summary>
    public void Update()
    {
    }

    /// <summary>No-op.</summary>
    public void Reset()
    {
    }

    /// <summary>Returns <c>CompressionSpec.Default()</c>.</summary>
    public CompressionSpec GetCompressionSpec() => CompressionSpec.Default();
}
/// <summary>
/// Tests for <c>SensorShapeValidator.ValidateSensors</c>. Each test feeds one sensor
/// list to a fresh validator and then a second list, declaring via
/// <c>LogAssert.Expect</c> which assert-type log messages (if any) the second call
/// should produce. The mismatch tests repeat the scenario with the two lists swapped.
/// </summary>
public class SensorShapeValidatorTests
{
    /// <summary>
    /// Builds a new three-sensor list: DummySensor(1), DummySensor(2, 3), DummySensor(4, 5, 6).
    /// </summary>
    static List<ISensor> BuildBaselineSensors()
    {
        return new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
            new DummySensor(4, 5, 6)
        };
    }

    /// <summary>
    /// Registers an expected assert-type log whose text matches the shape-mismatch pattern.
    /// </summary>
    static void ExpectShapeMismatch()
    {
        LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
    }

    /// <summary>
    /// Two identically constructed sensor lists are validated back to back with no
    /// log expectations registered.
    /// </summary>
    [Test]
    public void TestShapesAgree()
    {
        var shapeValidator = new SensorShapeValidator();

        var firstBatch = BuildBaselineSensors();
        shapeValidator.ValidateSensors(firstBatch);

        var secondBatch = BuildBaselineSensors();
        shapeValidator.ValidateSensors(secondBatch);
    }

    /// <summary>
    /// A three-sensor list followed by a two-sensor list (and vice versa) is expected
    /// to log the sensor-count mismatch message.
    /// </summary>
    [Test]
    public void TestNumSensorMismatch()
    {
        var forwardValidator = new SensorShapeValidator();
        var threeSensors = BuildBaselineSensors();
        forwardValidator.ValidateSensors(threeSensors);

        var twoSensors = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3)
        };
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
        forwardValidator.ValidateSensors(twoSensors);

        // Same scenario with the lists swapped: the shorter list goes in first.
        var reverseValidator = new SensorShapeValidator();
        reverseValidator.ValidateSensors(twoSensors);
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
        reverseValidator.ValidateSensors(threeSensors);
    }

    /// <summary>
    /// The last sensor differs in constructor arity (DummySensor(4, 5, 6) versus
    /// DummySensor(4, 5)); a shape-mismatch log is expected in both orders.
    /// </summary>
    [Test]
    public void TestDimensionMismatch()
    {
        var forwardValidator = new SensorShapeValidator();
        var baseline = BuildBaselineSensors();
        forwardValidator.ValidateSensors(baseline);

        var variant = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
            new DummySensor(4, 5)
        };
        ExpectShapeMismatch();
        forwardValidator.ValidateSensors(variant);

        // Same scenario with the lists swapped: the variant list goes in first.
        var reverseValidator = new SensorShapeValidator();
        reverseValidator.ValidateSensors(variant);
        ExpectShapeMismatch();
        reverseValidator.ValidateSensors(baseline);
    }

    /// <summary>
    /// The last sensor has the same arity but a different final argument
    /// (DummySensor(4, 5, 6) versus DummySensor(4, 5, 7)); a shape-mismatch log is
    /// expected in both orders.
    /// </summary>
    [Test]
    public void TestSizeMismatch()
    {
        var forwardValidator = new SensorShapeValidator();
        var baseline = BuildBaselineSensors();
        forwardValidator.ValidateSensors(baseline);

        var variant = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
            new DummySensor(4, 5, 7)
        };
        ExpectShapeMismatch();
        forwardValidator.ValidateSensors(variant);

        // Same scenario with the lists swapped: the variant list goes in first.
        var reverseValidator = new SensorShapeValidator();
        reverseValidator.ValidateSensors(variant);
        ExpectShapeMismatch();
        reverseValidator.ValidateSensors(baseline);
    }

    /// <summary>
    /// The second list differs both in length and in its second sensor
    /// (DummySensor(9) versus DummySensor(2, 3)); the count-mismatch message and then
    /// a shape-mismatch message are expected, in both orders.
    /// </summary>
    [Test]
    public void TestEverythingMismatch()
    {
        var forwardValidator = new SensorShapeValidator();
        var baseline = BuildBaselineSensors();
        forwardValidator.ValidateSensors(baseline);

        var variant = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(9)
        };
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
        ExpectShapeMismatch();
        forwardValidator.ValidateSensors(variant);

        // Same scenario with the lists swapped: the variant list goes in first.
        var reverseValidator = new SensorShapeValidator();
        reverseValidator.ValidateSensors(variant);
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
        ExpectShapeMismatch();
        reverseValidator.ValidateSensors(baseline);
    }
}
}