GitHub
5 年前
当前提交
d8567b82
共有 3 个文件被更改,包括 158 次插入 和 2 次删除
-
4com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs
-
145com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
-
11com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs.meta
|
|||
using System.Collections.Generic; |
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using UnityEngine.TestTools; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
public class DummySensor : ISensor |
|||
{ |
|||
string m_Name = "DummySensor"; |
|||
int[] m_Shape; |
|||
|
|||
public DummySensor(int dim1) |
|||
{ |
|||
m_Shape = new[] { dim1 }; |
|||
} |
|||
|
|||
public DummySensor(int dim1, int dim2) |
|||
{ |
|||
m_Shape = new[] { dim1, dim2, }; |
|||
} |
|||
|
|||
public DummySensor(int dim1, int dim2, int dim3) |
|||
{ |
|||
m_Shape = new[] { dim1, dim2, dim3}; |
|||
} |
|||
|
|||
public string GetName() |
|||
{ |
|||
return m_Name; |
|||
} |
|||
|
|||
public int[] GetObservationShape() |
|||
{ |
|||
return m_Shape; |
|||
} |
|||
|
|||
public byte[] GetCompressedObservation() |
|||
{ |
|||
return null; |
|||
} |
|||
|
|||
public int Write(WriteAdapter adapter) |
|||
{ |
|||
return this.ObservationSize(); |
|||
} |
|||
|
|||
public void Update() { } |
|||
|
|||
public SensorCompressionType GetCompressionType() |
|||
{ |
|||
return SensorCompressionType.None; |
|||
} |
|||
} |
|||
|
|||
public class SensorShapeValidatorTests |
|||
{ |
|||
|
|||
[Test] |
|||
public void TestShapesAgree() |
|||
{ |
|||
var validator = new SensorShapeValidator(); |
|||
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) }; |
|||
validator.ValidateSensors(sensorList1); |
|||
|
|||
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) }; |
|||
validator.ValidateSensors(sensorList2); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestNumSensorMismatch() |
|||
{ |
|||
var validator = new SensorShapeValidator(); |
|||
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) }; |
|||
validator.ValidateSensors(sensorList1); |
|||
|
|||
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), }; |
|||
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2"); |
|||
validator.ValidateSensors(sensorList2); |
|||
|
|||
// Add the sensors in the other order
|
|||
validator = new SensorShapeValidator(); |
|||
validator.ValidateSensors(sensorList2); |
|||
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3"); |
|||
validator.ValidateSensors(sensorList1); |
|||
} |
|||
[Test] |
|||
public void TestDimensionMismatch() |
|||
{ |
|||
var validator = new SensorShapeValidator(); |
|||
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) }; |
|||
validator.ValidateSensors(sensorList1); |
|||
|
|||
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5) }; |
|||
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match."); |
|||
validator.ValidateSensors(sensorList2); |
|||
|
|||
// Add the sensors in the other order
|
|||
validator = new SensorShapeValidator(); |
|||
validator.ValidateSensors(sensorList2); |
|||
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match."); |
|||
validator.ValidateSensors(sensorList1); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSizeMismatch() |
|||
{ |
|||
var validator = new SensorShapeValidator(); |
|||
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) }; |
|||
validator.ValidateSensors(sensorList1); |
|||
|
|||
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 7) }; |
|||
LogAssert.Expect(LogType.Assert, "Sensor sizes much match."); |
|||
validator.ValidateSensors(sensorList2); |
|||
|
|||
// Add the sensors in the other order
|
|||
validator = new SensorShapeValidator(); |
|||
validator.ValidateSensors(sensorList2); |
|||
LogAssert.Expect(LogType.Assert, "Sensor sizes much match."); |
|||
validator.ValidateSensors(sensorList1); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestEverythingMismatch() |
|||
{ |
|||
var validator = new SensorShapeValidator(); |
|||
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) }; |
|||
validator.ValidateSensors(sensorList1); |
|||
|
|||
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(9) }; |
|||
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2"); |
|||
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match."); |
|||
LogAssert.Expect(LogType.Assert, "Sensor sizes much match."); |
|||
validator.ValidateSensors(sensorList2); |
|||
|
|||
// Add the sensors in the other order
|
|||
validator = new SensorShapeValidator(); |
|||
validator.ValidateSensors(sensorList2); |
|||
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3"); |
|||
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match."); |
|||
LogAssert.Expect(LogType.Assert, "Sensor sizes much match."); |
|||
validator.ValidateSensors(sensorList1); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: bbfcd7a9de490454cbc37b8d7d900e7e |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
撰写
预览
正在加载...
取消
保存
Reference in new issue