Chris Elion
4 年前
当前提交
4890f5bc
共有 8 个文件被更改,包括 268 次插入 和 13 次删除
-
19com.unity.ml-agents/Runtime/InplaceArray.cs
-
2com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
-
2com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
-
2com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
-
175com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs
-
11com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
-
67com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs
-
3com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta
|
|||
using System; |
|||
using System.Collections; |
|||
using Boo.Lang.Runtime; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents; |
|||
using UnityEngine; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Tests |
|||
{ |
|||
|
|||
|
|||
[TestFixture] |
|||
public class InplaceArrayTests |
|||
{ |
|||
class LengthCases : IEnumerable |
|||
{ |
|||
public IEnumerator GetEnumerator() |
|||
{ |
|||
yield return 1; |
|||
yield return 2; |
|||
yield return 3; |
|||
yield return 4; |
|||
} |
|||
} |
|||
|
|||
private InplaceArray<int> GetTestArray(int length) |
|||
{ |
|||
switch (length) |
|||
{ |
|||
case 1: |
|||
return new InplaceArray<int>(11); |
|||
case 2: |
|||
return new InplaceArray<int>(11, 22); |
|||
case 3: |
|||
return new InplaceArray<int>(11, 22, 33); |
|||
case 4: |
|||
return new InplaceArray<int>(11, 22, 33, 44); |
|||
default: |
|||
throw new RuntimeException("bad test!"); |
|||
} |
|||
} |
|||
|
|||
private InplaceArray<int> GetZeroArray(int length) |
|||
{ |
|||
switch (length) |
|||
{ |
|||
case 1: |
|||
return new InplaceArray<int>(0); |
|||
case 2: |
|||
return new InplaceArray<int>(0, 0); |
|||
case 3: |
|||
return new InplaceArray<int>(0, 0, 0); |
|||
case 4: |
|||
return new InplaceArray<int>(0, 0, 0, 0); |
|||
default: |
|||
throw new RuntimeException("bad test!"); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestInplaceArrayCtor() |
|||
{ |
|||
var a1 = new InplaceArray<int>(11); |
|||
Assert.AreEqual(1, a1.Length); |
|||
Assert.AreEqual(11, a1[0]); |
|||
|
|||
var a2 = new InplaceArray<int>(11, 22); |
|||
Assert.AreEqual(2, a2.Length); |
|||
Assert.AreEqual(11, a2[0]); |
|||
Assert.AreEqual(22, a2[1]); |
|||
|
|||
var a3 = new InplaceArray<int>(11, 22, 33); |
|||
Assert.AreEqual(3, a3.Length); |
|||
Assert.AreEqual(11, a3[0]); |
|||
Assert.AreEqual(22, a3[1]); |
|||
Assert.AreEqual(33, a3[2]); |
|||
|
|||
var a4 = new InplaceArray<int>(11, 22, 33, 44); |
|||
Assert.AreEqual(4, a4.Length); |
|||
Assert.AreEqual(11, a4[0]); |
|||
Assert.AreEqual(22, a4[1]); |
|||
Assert.AreEqual(33, a4[2]); |
|||
Assert.AreEqual(44, a4[3]); |
|||
} |
|||
|
|||
[TestCaseSource(typeof(LengthCases))] |
|||
public void TestInplaceGetSet(int length) |
|||
{ |
|||
var original = GetTestArray(length); |
|||
|
|||
for (var i = 0; i < original.Length; i++) |
|||
{ |
|||
var modified = original; |
|||
modified[i] = 0; |
|||
for (var j = 0; j < original.Length; j++) |
|||
{ |
|||
if (i == j) |
|||
{ |
|||
// This is the one we overwrote
|
|||
Assert.AreEqual(0, modified[j]); |
|||
} |
|||
else |
|||
{ |
|||
// Other elements should be unchanged
|
|||
Assert.AreEqual(original[j], modified[j]); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
[TestCaseSource(typeof(LengthCases))] |
|||
public void TestInvalidAccess(int length) |
|||
{ |
|||
var tmp = 0; |
|||
var a = GetTestArray(length); |
|||
// get
|
|||
Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[-1]; }); |
|||
Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[length]; }); |
|||
|
|||
// set
|
|||
Assert.Throws<IndexOutOfRangeException>(() => { a[-1] = 0; }); |
|||
Assert.Throws<IndexOutOfRangeException>(() => { a[length] = 0; }); |
|||
|
|||
// Make sure temp is used
|
|||
Assert.AreEqual(0, tmp); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestOperatorEqualsDifferentLengths() |
|||
{ |
|||
// Check arrays of different length are never equal (even if they have 0s in all elements)
|
|||
for (var l1 = 1; l1 <= 4; l1++) |
|||
{ |
|||
var a1 = GetZeroArray(l1); |
|||
for (var l2 = 1; l2 <= 4; l2++) |
|||
{ |
|||
var a2 = GetZeroArray(l2); |
|||
if (l1 == l2) |
|||
{ |
|||
Assert.AreEqual(a1, a2); |
|||
Assert.IsTrue(a1 == a2); |
|||
} |
|||
else |
|||
{ |
|||
Assert.AreNotEqual(a1, a2); |
|||
Assert.IsTrue(a1 != a2); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
[TestCaseSource(typeof(LengthCases))] |
|||
public void TestOperatorEquals(int length) |
|||
{ |
|||
for (var index = 0; index < length; index++) |
|||
{ |
|||
var a1 = GetTestArray(length); |
|||
var a2 = GetTestArray(length); |
|||
Assert.AreEqual(a1, a2); |
|||
Assert.IsTrue(a1 == a2); |
|||
|
|||
a1[index] = 42; |
|||
Assert.AreNotEqual(a1, a2); |
|||
Assert.IsTrue(a1 != a2); |
|||
|
|||
a2[index] = 42; |
|||
Assert.AreEqual(a1, a2); |
|||
Assert.IsTrue(a1 == a2); |
|||
} |
|||
} |
|||
|
|||
|
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 8e1cdc27e533749fabc04b3cdeb93501 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Tests |
|||
{ |
|||
[TestFixture] |
|||
public class ObservationSpecTests |
|||
{ |
|||
[Test] |
|||
public void TestVectorObsSpec() |
|||
{ |
|||
var obsSpec = ObservationSpec.Vector(5); |
|||
Assert.AreEqual(1, obsSpec.NumDimensions); |
|||
|
|||
var shape = obsSpec.Shape; |
|||
Assert.AreEqual(1, shape.Length); |
|||
Assert.AreEqual(5, shape[0]); |
|||
|
|||
var dimensionProps = obsSpec.DimensionProperties; |
|||
Assert.AreEqual(1, dimensionProps.Length); |
|||
Assert.AreEqual(DimensionProperty.None, dimensionProps[0]); |
|||
|
|||
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestVariableLengthObsSpec() |
|||
{ |
|||
var obsSpec = ObservationSpec.VariableLength(5, 6); |
|||
Assert.AreEqual(2, obsSpec.NumDimensions); |
|||
|
|||
var shape = obsSpec.Shape; |
|||
Assert.AreEqual(2, shape.Length); |
|||
Assert.AreEqual(5, shape[0]); |
|||
Assert.AreEqual(6, shape[1]); |
|||
|
|||
var dimensionProps = obsSpec.DimensionProperties; |
|||
Assert.AreEqual(2, dimensionProps.Length); |
|||
Assert.AreEqual(DimensionProperty.VariableSize, dimensionProps[0]); |
|||
Assert.AreEqual(DimensionProperty.None, dimensionProps[1]); |
|||
|
|||
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestVisualObsSpec() |
|||
{ |
|||
var obsSpec = ObservationSpec.Visual(5, 6, 7); |
|||
Assert.AreEqual(3, obsSpec.NumDimensions); |
|||
|
|||
var shape = obsSpec.Shape; |
|||
Assert.AreEqual(3, shape.Length); |
|||
Assert.AreEqual(5, shape[0]); |
|||
Assert.AreEqual(6, shape[1]); |
|||
Assert.AreEqual(7, shape[2]); |
|||
|
|||
var dimensionProps = obsSpec.DimensionProperties; |
|||
Assert.AreEqual(3, dimensionProps.Length); |
|||
Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[0]); |
|||
Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[1]); |
|||
Assert.AreEqual(DimensionProperty.None, dimensionProps[2]); |
|||
|
|||
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 27ff1979bd5e4b8ebeb4d98f414a5090 |
|||
timeCreated: 1615863866 |
撰写
预览
正在加载...
取消
保存
Reference in new issue