浏览代码

InplaceArray and ObsSpec test

/v2-staging-rebase
Chris Elion 4 年前
当前提交
4890f5bc
共有 8 个文件被更改,包括 268 次插入13 次删除
  1. 19
      com.unity.ml-agents/Runtime/InplaceArray.cs
  2. 2
      com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
  3. 2
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
  4. 2
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  5. 175
      com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs
  6. 11
      com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
  7. 67
      com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs
  8. 3
      com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta

19
com.unity.ml-agents/Runtime/InplaceArray.cs


using System;
using System.Collections.Generic;
using System.Linq.Expressions;
namespace Unity.MLAgents
{

private int m_Length;
private readonly int m_Length;
private T m_elem0;
private T m_elem1;

{
get
{
if (index < 0 || index >= k_MaxLength)
if (index < 0 || index >= Length)
throw new ArgumentOutOfRangeException();
throw new IndexOutOfRangeException();
}
switch (index)

case 3:
return m_elem3;
default:
throw new ArgumentOutOfRangeException();
throw new IndexOutOfRangeException();
internal set
set
if (index < 0 || index >= k_MaxLength)
if (index < 0 || index >= Length)
throw new ArgumentOutOfRangeException();
throw new IndexOutOfRangeException();
}
switch (index)

m_elem3 = value;
break;
default:
throw new ArgumentOutOfRangeException();
throw new IndexOutOfRangeException();
}
}
}

case 4:
return $"[{m_elem0}, {m_elem1}, {m_elem2}, {m_elem3}]";
default:
throw new ArgumentOutOfRangeException();
throw new IndexOutOfRangeException();
}
}

2
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs


m_ObsSize = obsSize;
m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs];
m_CurrentNumObservables = 0;
m_ObservationSpec = ObservationSpec.VariableSize(m_MaxNumObs, m_ObsSize);
m_ObservationSpec = ObservationSpec.VariableLength(m_MaxNumObs, m_ObsSize);
}
/// <inheritdoc/>

2
com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs


return new ObservationSpec(shape, dimProps);
}
public static ObservationSpec VariableSize(int obsSize, int maxNumObs)
public static ObservationSpec VariableLength(int obsSize, int maxNumObs)
{
InplaceArray<int> shape = new InplaceArray<int>(obsSize, maxNumObs);
InplaceArray<DimensionProperty> dimProps = new InplaceArray<DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None);

2
com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs


public DummySensor(int dim1, int dim2)
{
m_ObservationSpec = ObservationSpec.VariableSize(dim1, dim2);
m_ObservationSpec = ObservationSpec.VariableLength(dim1, dim2);
}
public DummySensor(int dim1, int dim2, int dim3)

175
com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs


using System;
using System.Collections;
using Boo.Lang.Runtime;
using NUnit.Framework;
using Unity.MLAgents;
using UnityEngine;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class InplaceArrayTests
{
class LengthCases : IEnumerable
{
public IEnumerator GetEnumerator()
{
yield return 1;
yield return 2;
yield return 3;
yield return 4;
}
}
private InplaceArray<int> GetTestArray(int length)
{
switch (length)
{
case 1:
return new InplaceArray<int>(11);
case 2:
return new InplaceArray<int>(11, 22);
case 3:
return new InplaceArray<int>(11, 22, 33);
case 4:
return new InplaceArray<int>(11, 22, 33, 44);
default:
throw new RuntimeException("bad test!");
}
}
private InplaceArray<int> GetZeroArray(int length)
{
switch (length)
{
case 1:
return new InplaceArray<int>(0);
case 2:
return new InplaceArray<int>(0, 0);
case 3:
return new InplaceArray<int>(0, 0, 0);
case 4:
return new InplaceArray<int>(0, 0, 0, 0);
default:
throw new RuntimeException("bad test!");
}
}
[Test]
public void TestInplaceArrayCtor()
{
var a1 = new InplaceArray<int>(11);
Assert.AreEqual(1, a1.Length);
Assert.AreEqual(11, a1[0]);
var a2 = new InplaceArray<int>(11, 22);
Assert.AreEqual(2, a2.Length);
Assert.AreEqual(11, a2[0]);
Assert.AreEqual(22, a2[1]);
var a3 = new InplaceArray<int>(11, 22, 33);
Assert.AreEqual(3, a3.Length);
Assert.AreEqual(11, a3[0]);
Assert.AreEqual(22, a3[1]);
Assert.AreEqual(33, a3[2]);
var a4 = new InplaceArray<int>(11, 22, 33, 44);
Assert.AreEqual(4, a4.Length);
Assert.AreEqual(11, a4[0]);
Assert.AreEqual(22, a4[1]);
Assert.AreEqual(33, a4[2]);
Assert.AreEqual(44, a4[3]);
}
[TestCaseSource(typeof(LengthCases))]
public void TestInplaceGetSet(int length)
{
var original = GetTestArray(length);
for (var i = 0; i < original.Length; i++)
{
var modified = original;
modified[i] = 0;
for (var j = 0; j < original.Length; j++)
{
if (i == j)
{
// This is the one we overwrote
Assert.AreEqual(0, modified[j]);
}
else
{
// Other elements should be unchanged
Assert.AreEqual(original[j], modified[j]);
}
}
}
}
[TestCaseSource(typeof(LengthCases))]
public void TestInvalidAccess(int length)
{
var tmp = 0;
var a = GetTestArray(length);
// get
Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[-1]; });
Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[length]; });
// set
Assert.Throws<IndexOutOfRangeException>(() => { a[-1] = 0; });
Assert.Throws<IndexOutOfRangeException>(() => { a[length] = 0; });
// Make sure temp is used
Assert.AreEqual(0, tmp);
}
[Test]
public void TestOperatorEqualsDifferentLengths()
{
// Check arrays of different length are never equal (even if they have 0s in all elements)
for (var l1 = 1; l1 <= 4; l1++)
{
var a1 = GetZeroArray(l1);
for (var l2 = 1; l2 <= 4; l2++)
{
var a2 = GetZeroArray(l2);
if (l1 == l2)
{
Assert.AreEqual(a1, a2);
Assert.IsTrue(a1 == a2);
}
else
{
Assert.AreNotEqual(a1, a2);
Assert.IsTrue(a1 != a2);
}
}
}
}
[TestCaseSource(typeof(LengthCases))]
public void TestOperatorEquals(int length)
{
for (var index = 0; index < length; index++)
{
var a1 = GetTestArray(length);
var a2 = GetTestArray(length);
Assert.AreEqual(a1, a2);
Assert.IsTrue(a1 == a2);
a1[index] = 42;
Assert.AreNotEqual(a1, a2);
Assert.IsTrue(a1 != a2);
a2[index] = 42;
Assert.AreEqual(a1, a2);
Assert.IsTrue(a1 == a2);
}
}
}
}

11
com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta


fileFormatVersion: 2
guid: 8e1cdc27e533749fabc04b3cdeb93501
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

67
com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs


using NUnit.Framework;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class ObservationSpecTests
{
[Test]
public void TestVectorObsSpec()
{
var obsSpec = ObservationSpec.Vector(5);
Assert.AreEqual(1, obsSpec.NumDimensions);
var shape = obsSpec.Shape;
Assert.AreEqual(1, shape.Length);
Assert.AreEqual(5, shape[0]);
var dimensionProps = obsSpec.DimensionProperties;
Assert.AreEqual(1, dimensionProps.Length);
Assert.AreEqual(DimensionProperty.None, dimensionProps[0]);
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
}
[Test]
public void TestVariableLengthObsSpec()
{
var obsSpec = ObservationSpec.VariableLength(5, 6);
Assert.AreEqual(2, obsSpec.NumDimensions);
var shape = obsSpec.Shape;
Assert.AreEqual(2, shape.Length);
Assert.AreEqual(5, shape[0]);
Assert.AreEqual(6, shape[1]);
var dimensionProps = obsSpec.DimensionProperties;
Assert.AreEqual(2, dimensionProps.Length);
Assert.AreEqual(DimensionProperty.VariableSize, dimensionProps[0]);
Assert.AreEqual(DimensionProperty.None, dimensionProps[1]);
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
}
[Test]
public void TestVisualObsSpec()
{
var obsSpec = ObservationSpec.Visual(5, 6, 7);
Assert.AreEqual(3, obsSpec.NumDimensions);
var shape = obsSpec.Shape;
Assert.AreEqual(3, shape.Length);
Assert.AreEqual(5, shape[0]);
Assert.AreEqual(6, shape[1]);
Assert.AreEqual(7, shape[2]);
var dimensionProps = obsSpec.DimensionProperties;
Assert.AreEqual(3, dimensionProps.Length);
Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[0]);
Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[1]);
Assert.AreEqual(DimensionProperty.None, dimensionProps[2]);
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType);
}
}
}

3
com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta


fileFormatVersion: 2
guid: 27ff1979bd5e4b8ebeb4d98f414a5090
timeCreated: 1615863866
正在加载...
取消
保存