浏览代码

Add SensorShapeValidator unit (#3504)

/asymm-envs
GitHub 5 年前
当前提交
d8567b82
共有 3 个文件被更改,包括 158 次插入2 次删除
  1. 4
      com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs
  2. 145
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  3. 11
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs.meta

4
com.unity.ml-agents/Runtime/Sensor/SensorShapeValidator.cs


// Check for compatibility with the other Agents' Sensors
// TODO make sure this only checks once per agent
Debug.Assert(m_SensorShapes.Count == sensors.Count, $"Number of Sensors must match. {m_SensorShapes.Count} != {sensors.Count}");
for (var i = 0; i < m_SensorShapes.Count; i++)
for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++)
for (var j = 0; j < cachedShape.Length; j++)
for (var j = 0; j < Mathf.Min(cachedShape.Length, sensorShape.Length); j++)
{
Debug.Assert(cachedShape[j] == sensorShape[j], "Sensor sizes much match.");
}

145
com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs


using System.Collections.Generic;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;
namespace MLAgents.Tests
{
public class DummySensor : ISensor
{
string m_Name = "DummySensor";
int[] m_Shape;
public DummySensor(int dim1)
{
m_Shape = new[] { dim1 };
}
public DummySensor(int dim1, int dim2)
{
m_Shape = new[] { dim1, dim2, };
}
public DummySensor(int dim1, int dim2, int dim3)
{
m_Shape = new[] { dim1, dim2, dim3};
}
public string GetName()
{
return m_Name;
}
public int[] GetObservationShape()
{
return m_Shape;
}
public byte[] GetCompressedObservation()
{
return null;
}
public int Write(WriteAdapter adapter)
{
return this.ObservationSize();
}
public void Update() { }
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}
}
public class SensorShapeValidatorTests
{
[Test]
public void TestShapesAgree()
{
var validator = new SensorShapeValidator();
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
validator.ValidateSensors(sensorList2);
}
[Test]
public void TestNumSensorMismatch()
{
var validator = new SensorShapeValidator();
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), };
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
validator.ValidateSensors(sensorList2);
// Add the sensors in the other order
validator = new SensorShapeValidator();
validator.ValidateSensors(sensorList2);
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
validator.ValidateSensors(sensorList1);
}
[Test]
public void TestDimensionMismatch()
{
var validator = new SensorShapeValidator();
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5) };
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
validator.ValidateSensors(sensorList2);
// Add the sensors in the other order
validator = new SensorShapeValidator();
validator.ValidateSensors(sensorList2);
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
validator.ValidateSensors(sensorList1);
}
[Test]
public void TestSizeMismatch()
{
var validator = new SensorShapeValidator();
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 7) };
LogAssert.Expect(LogType.Assert, "Sensor sizes much match.");
validator.ValidateSensors(sensorList2);
// Add the sensors in the other order
validator = new SensorShapeValidator();
validator.ValidateSensors(sensorList2);
LogAssert.Expect(LogType.Assert, "Sensor sizes much match.");
validator.ValidateSensors(sensorList1);
}
[Test]
public void TestEverythingMismatch()
{
var validator = new SensorShapeValidator();
var sensorList1 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 6) };
validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(9) };
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, "Sensor sizes much match.");
validator.ValidateSensors(sensorList2);
// Add the sensors in the other order
validator = new SensorShapeValidator();
validator.ValidateSensors(sensorList2);
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, "Sensor sizes much match.");
validator.ValidateSensors(sensorList1);
}
}
}

11
com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs.meta


fileFormatVersion: 2
guid: bbfcd7a9de490454cbc37b8d7d900e7e
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存