浏览代码
[MLA-1634] Add ObservationSpec and update ISensor interfaces (#5127)
/goal-conditioning/sensors-3-pytest-fix
[MLA-1634] Add ObservationSpec and update ISensor interfaces (#5127)
/goal-conditioning/sensors-3-pytest-fix
Christopher Goy
4 年前
当前提交
113c1bca
共有 55 个文件被更改,包括 1009 次插入 和 374 次删除
-
2DevProject/Packages/manifest.json
-
2DevProject/Packages/packages-lock.json
-
18DevProject/ProjectSettings/Packages/com.unity.testtools.codecoverage/Settings.json
-
4DevProject/ProjectSettings/ProjectVersion.txt
-
4Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
-
2Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
-
5Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta
-
8Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
-
12com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
-
30com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
-
10com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
-
12com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
-
8com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs
-
6com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs
-
10com.unity.ml-agents/CHANGELOG.md
-
1com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
-
7com.unity.ml-agents/Runtime/Analytics/Events.cs
-
55com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
-
26com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
-
9com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
-
2com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
-
4com.unity.ml-agents/Runtime/SensorHelper.cs
-
24com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
-
32com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
-
70com.unity.ml-agents/Runtime/Sensors/ISensor.cs
-
13com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
-
8com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
-
12com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
-
8com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
-
21com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
-
48com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
-
10com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
-
23com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
-
19com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
-
4com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
-
6com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs
-
3com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs
-
13com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
-
2com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs
-
20com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs
-
2com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs
-
27com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
-
30com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
-
2com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
-
241com.unity.ml-agents/Runtime/InplaceArray.cs
-
3com.unity.ml-agents/Runtime/InplaceArray.cs.meta
-
140com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
-
3com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs.meta
-
192com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs
-
78com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs
-
3com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs.meta
-
31com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs
-
11com.unity.ml-agents/Runtime/Sensors/ITypedSensor.cs.meta
-
47com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs
-
0/com.unity.ml-agents/Tests/Editor/InplaceArrayTests.cs.meta
|
|||
m_EditorVersion: 2019.4.19f1 |
|||
m_EditorVersionWithRevision: 2019.4.19f1 (ca5b14067cec) |
|||
m_EditorVersion: 2019.4.20f1 |
|||
m_EditorVersionWithRevision: 2019.4.20f1 (6dd1c08eedfa) |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
|
|||
namespace Unity.MLAgents |
|||
{ |
|||
/// <summary>
|
|||
/// An array-like object that stores up to four elements.
|
|||
/// This is a value type that does not allocate any additional memory.
|
|||
/// </summary>
|
|||
/// <remarks>
|
|||
/// This does not implement any interfaces such as IList, in order to avoid any accidental boxing allocations.
|
|||
/// </remarks>
|
|||
/// <typeparam name="T"></typeparam>
|
|||
public struct InplaceArray<T> : IEquatable<InplaceArray<T>> where T : struct |
|||
{ |
|||
private const int k_MaxLength = 4; |
|||
private readonly int m_Length; |
|||
|
|||
private T m_Elem0; |
|||
private T m_Elem1; |
|||
private T m_Elem2; |
|||
private T m_Elem3; |
|||
|
|||
/// <summary>
|
|||
/// Create a length-1 array.
|
|||
/// </summary>
|
|||
/// <param name="elem0"></param>
|
|||
public InplaceArray(T elem0) |
|||
{ |
|||
m_Length = 1; |
|||
m_Elem0 = elem0; |
|||
m_Elem1 = new T { }; |
|||
m_Elem2 = new T { }; |
|||
m_Elem3 = new T { }; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Create a length-2 array.
|
|||
/// </summary>
|
|||
/// <param name="elem0"></param>
|
|||
/// <param name="elem1"></param>
|
|||
public InplaceArray(T elem0, T elem1) |
|||
{ |
|||
m_Length = 2; |
|||
m_Elem0 = elem0; |
|||
m_Elem1 = elem1; |
|||
m_Elem2 = new T { }; |
|||
m_Elem3 = new T { }; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Create a length-3 array.
|
|||
/// </summary>
|
|||
/// <param name="elem0"></param>
|
|||
/// <param name="elem1"></param>
|
|||
/// <param name="elem2"></param>
|
|||
public InplaceArray(T elem0, T elem1, T elem2) |
|||
{ |
|||
m_Length = 3; |
|||
m_Elem0 = elem0; |
|||
m_Elem1 = elem1; |
|||
m_Elem2 = elem2; |
|||
m_Elem3 = new T { }; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Create a length-3 array.
|
|||
/// </summary>
|
|||
/// <param name="elem0"></param>
|
|||
/// <param name="elem1"></param>
|
|||
/// <param name="elem2"></param>
|
|||
/// <param name="elem3"></param>
|
|||
public InplaceArray(T elem0, T elem1, T elem2, T elem3) |
|||
{ |
|||
m_Length = 4; |
|||
m_Elem0 = elem0; |
|||
m_Elem1 = elem1; |
|||
m_Elem2 = elem2; |
|||
m_Elem3 = elem3; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Construct an InplaceArray from an IList (e.g. Array or List).
|
|||
/// The source must be non-empty and have at most 4 elements.
|
|||
/// </summary>
|
|||
/// <param name="elems"></param>
|
|||
/// <returns></returns>
|
|||
/// <exception cref="ArgumentOutOfRangeException"></exception>
|
|||
public static InplaceArray<T> FromList(IList<T> elems) |
|||
{ |
|||
switch (elems.Count) |
|||
{ |
|||
case 1: |
|||
return new InplaceArray<T>(elems[0]); |
|||
case 2: |
|||
return new InplaceArray<T>(elems[0], elems[1]); |
|||
case 3: |
|||
return new InplaceArray<T>(elems[0], elems[1], elems[2]); |
|||
case 4: |
|||
return new InplaceArray<T>(elems[0], elems[1], elems[2], elems[3]); |
|||
default: |
|||
throw new ArgumentOutOfRangeException(); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Per-element access.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <exception cref="IndexOutOfRangeException"></exception>
|
|||
public T this[int index] |
|||
{ |
|||
get |
|||
{ |
|||
if (index >= Length) |
|||
{ |
|||
throw new IndexOutOfRangeException(); |
|||
} |
|||
|
|||
switch (index) |
|||
{ |
|||
case 0: |
|||
return m_Elem0; |
|||
case 1: |
|||
return m_Elem1; |
|||
case 2: |
|||
return m_Elem2; |
|||
case 3: |
|||
return m_Elem3; |
|||
default: |
|||
throw new IndexOutOfRangeException(); |
|||
} |
|||
} |
|||
|
|||
set |
|||
{ |
|||
if (index >= Length) |
|||
{ |
|||
throw new IndexOutOfRangeException(); |
|||
} |
|||
|
|||
switch (index) |
|||
{ |
|||
case 0: |
|||
m_Elem0 = value; |
|||
break; |
|||
case 1: |
|||
m_Elem1 = value; |
|||
break; |
|||
case 2: |
|||
m_Elem2 = value; |
|||
break; |
|||
case 3: |
|||
m_Elem3 = value; |
|||
break; |
|||
default: |
|||
throw new IndexOutOfRangeException(); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// The length of the array.
|
|||
/// </summary>
|
|||
public int Length |
|||
{ |
|||
get => m_Length; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns a string representation of the array's elements.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
/// <exception cref="IndexOutOfRangeException"></exception>
|
|||
public override string ToString() |
|||
{ |
|||
switch (m_Length) |
|||
{ |
|||
case 1: |
|||
return $"[{m_Elem0}]"; |
|||
case 2: |
|||
return $"[{m_Elem0}, {m_Elem1}]"; |
|||
case 3: |
|||
return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}]"; |
|||
case 4: |
|||
return $"[{m_Elem0}, {m_Elem1}, {m_Elem2}, {m_Elem3}]"; |
|||
default: |
|||
throw new IndexOutOfRangeException(); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Check that the arrays have the same length and have all equal values.
|
|||
/// </summary>
|
|||
/// <param name="lhs"></param>
|
|||
/// <param name="rhs"></param>
|
|||
/// <returns>Whether the arrays are equivalent.</returns>
|
|||
public static bool operator ==(InplaceArray<T> lhs, InplaceArray<T> rhs) |
|||
{ |
|||
return lhs.Equals(rhs); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Check that the arrays are not equivalent.
|
|||
/// </summary>
|
|||
/// <param name="lhs"></param>
|
|||
/// <param name="rhs"></param>
|
|||
/// <returns>Whether the arrays are not equivalent</returns>
|
|||
public static bool operator !=(InplaceArray<T> lhs, InplaceArray<T> rhs) => !lhs.Equals(rhs); |
|||
|
|||
/// <summary>
|
|||
/// Check that the arrays are equivalent.
|
|||
/// </summary>
|
|||
/// <param name="other"></param>
|
|||
/// <returns>Whether the arrays are not equivalent</returns>
|
|||
public override bool Equals(object other) => other is InplaceArray<T> other1 && this.Equals(other1); |
|||
|
|||
/// <summary>
|
|||
/// Check that the arrays are equivalent.
|
|||
/// </summary>
|
|||
/// <param name="other"></param>
|
|||
/// <returns>Whether the arrays are not equivalent</returns>
|
|||
public bool Equals(InplaceArray<T> other) |
|||
{ |
|||
// See https://montemagno.com/optimizing-c-struct-equality-with-iequatable/
|
|||
var thisTuple = (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length); |
|||
var otherTuple = (other.m_Elem0, other.m_Elem1, other.m_Elem2, other.m_Elem3, other.Length); |
|||
return thisTuple.Equals(otherTuple); |
|||
|
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Get a hashcode for the array.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public override int GetHashCode() |
|||
{ |
|||
return (m_Elem0, m_Elem1, m_Elem2, m_Elem3, Length).GetHashCode(); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: c1a80abee18a41c8aee89aeb33f5985d |
|||
timeCreated: 1615506199 |
|
|||
using Unity.Barracuda; |
|||
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// A description of the observations that an ISensor produces.
|
|||
/// This includes the size of the observation, the properties of each dimension, and how the observation
|
|||
/// should be used for training.
|
|||
/// </summary>
|
|||
public struct ObservationSpec |
|||
{ |
|||
internal readonly InplaceArray<int> m_Shape; |
|||
|
|||
/// <summary>
|
|||
/// The size of the observations that will be generated.
|
|||
/// For example, a sensor that observes the velocity of a rigid body (in 3D) would use [3].
|
|||
/// A sensor that returns an RGB image would use [Height, Width, 3].
|
|||
/// </summary>
|
|||
public InplaceArray<int> Shape |
|||
{ |
|||
get => m_Shape; |
|||
} |
|||
|
|||
internal readonly InplaceArray<DimensionProperty> m_DimensionProperties; |
|||
|
|||
/// <summary>
|
|||
/// The properties of each dimensions of the observation.
|
|||
/// The length of the array must be equal to the rank of the observation tensor.
|
|||
/// </summary>
|
|||
/// <remarks>
|
|||
/// It is generally recommended to use default values provided by helper functions,
|
|||
/// as not all combinations of DimensionProperty may be supported by the trainer.
|
|||
/// </remarks>
|
|||
public InplaceArray<DimensionProperty> DimensionProperties |
|||
{ |
|||
get => m_DimensionProperties; |
|||
} |
|||
|
|||
internal ObservationType m_ObservationType; |
|||
|
|||
/// <summary>
|
|||
/// The type of the observation, e.g. whether they are generic or
|
|||
/// help determine the goal for the Agent.
|
|||
/// </summary>
|
|||
public ObservationType ObservationType |
|||
{ |
|||
get => m_ObservationType; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// The number of dimensions of the observation.
|
|||
/// </summary>
|
|||
public int Rank |
|||
{ |
|||
get { return Shape.Length; } |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Construct an ObservationSpec for 1-D observations of the requested length.
|
|||
/// </summary>
|
|||
/// <param name="length"></param>
|
|||
/// <param name="obsType"></param>
|
|||
/// <returns></returns>
|
|||
public static ObservationSpec Vector(int length, ObservationType obsType = ObservationType.Default) |
|||
{ |
|||
return new ObservationSpec( |
|||
new InplaceArray<int>(length), |
|||
new InplaceArray<DimensionProperty>(DimensionProperty.None), |
|||
obsType |
|||
); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Construct an ObservationSpec for variable-length observations.
|
|||
/// </summary>
|
|||
/// <param name="obsSize"></param>
|
|||
/// <param name="maxNumObs"></param>
|
|||
/// <returns></returns>
|
|||
public static ObservationSpec VariableLength(int obsSize, int maxNumObs) |
|||
{ |
|||
var dimProps = new InplaceArray<DimensionProperty>( |
|||
DimensionProperty.VariableSize, |
|||
DimensionProperty.None |
|||
); |
|||
return new ObservationSpec( |
|||
new InplaceArray<int>(obsSize, maxNumObs), |
|||
dimProps |
|||
); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Construct an ObservationSpec for visual-like observations, e.g. observations
|
|||
/// with a height, width, and possible multiple channels.
|
|||
/// </summary>
|
|||
/// <param name="height"></param>
|
|||
/// <param name="width"></param>
|
|||
/// <param name="channels"></param>
|
|||
/// <param name="obsType"></param>
|
|||
/// <returns></returns>
|
|||
public static ObservationSpec Visual(int height, int width, int channels, ObservationType obsType = ObservationType.Default) |
|||
{ |
|||
var dimProps = new InplaceArray<DimensionProperty>( |
|||
DimensionProperty.TranslationalEquivariance, |
|||
DimensionProperty.TranslationalEquivariance, |
|||
DimensionProperty.None |
|||
); |
|||
return new ObservationSpec( |
|||
new InplaceArray<int>(height, width, channels), |
|||
dimProps, |
|||
obsType |
|||
); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Create a general ObservationSpec from the shape, dimension properties, and observation type.
|
|||
/// </summary>
|
|||
/// <remarks>
|
|||
/// Note that not all combinations of DimensionProperty may be supported by the trainer.
|
|||
/// shape and dimensionProperties must have the same size.
|
|||
/// </remarks>
|
|||
/// <param name="shape"></param>
|
|||
/// <param name="dimensionProperties"></param>
|
|||
/// <param name="observationType"></param>
|
|||
/// <exception cref="UnityAgentsException"></exception>
|
|||
public ObservationSpec( |
|||
InplaceArray<int> shape, |
|||
InplaceArray<DimensionProperty> dimensionProperties, |
|||
ObservationType observationType = ObservationType.Default |
|||
) |
|||
{ |
|||
if (shape.Length != dimensionProperties.Length) |
|||
{ |
|||
throw new UnityAgentsException("shape and dimensionProperties must have the same length."); |
|||
} |
|||
m_Shape = shape; |
|||
m_DimensionProperties = dimensionProperties; |
|||
m_ObservationType = observationType; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: cc1734d60fd5485ead94247cb206aa35 |
|||
timeCreated: 1615412644 |
|
|||
using System; |
|||
using System.Collections; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents; |
|||
using UnityEngine; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Tests |
|||
{ |
|||
[TestFixture] |
|||
public class InplaceArrayTests |
|||
{ |
|||
class LengthCases : IEnumerable |
|||
{ |
|||
public IEnumerator GetEnumerator() |
|||
{ |
|||
yield return 1; |
|||
yield return 2; |
|||
yield return 3; |
|||
yield return 4; |
|||
} |
|||
} |
|||
|
|||
private InplaceArray<int> GetTestArray(int length) |
|||
{ |
|||
switch (length) |
|||
{ |
|||
case 1: |
|||
return new InplaceArray<int>(11); |
|||
case 2: |
|||
return new InplaceArray<int>(11, 22); |
|||
case 3: |
|||
return new InplaceArray<int>(11, 22, 33); |
|||
case 4: |
|||
return new InplaceArray<int>(11, 22, 33, 44); |
|||
default: |
|||
throw new ArgumentException("bad test!"); |
|||
} |
|||
} |
|||
|
|||
private InplaceArray<int> GetZeroArray(int length) |
|||
{ |
|||
switch (length) |
|||
{ |
|||
case 1: |
|||
return new InplaceArray<int>(0); |
|||
case 2: |
|||
return new InplaceArray<int>(0, 0); |
|||
case 3: |
|||
return new InplaceArray<int>(0, 0, 0); |
|||
case 4: |
|||
return new InplaceArray<int>(0, 0, 0, 0); |
|||
default: |
|||
throw new ArgumentException("bad test!"); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestInplaceArrayCtor() |
|||
{ |
|||
var a1 = new InplaceArray<int>(11); |
|||
Assert.AreEqual(1, a1.Length); |
|||
Assert.AreEqual(11, a1[0]); |
|||
|
|||
var a2 = new InplaceArray<int>(11, 22); |
|||
Assert.AreEqual(2, a2.Length); |
|||
Assert.AreEqual(11, a2[0]); |
|||
Assert.AreEqual(22, a2[1]); |
|||
|
|||
var a3 = new InplaceArray<int>(11, 22, 33); |
|||
Assert.AreEqual(3, a3.Length); |
|||
Assert.AreEqual(11, a3[0]); |
|||
Assert.AreEqual(22, a3[1]); |
|||
Assert.AreEqual(33, a3[2]); |
|||
|
|||
var a4 = new InplaceArray<int>(11, 22, 33, 44); |
|||
Assert.AreEqual(4, a4.Length); |
|||
Assert.AreEqual(11, a4[0]); |
|||
Assert.AreEqual(22, a4[1]); |
|||
Assert.AreEqual(33, a4[2]); |
|||
Assert.AreEqual(44, a4[3]); |
|||
} |
|||
|
|||
[TestCaseSource(typeof(LengthCases))] |
|||
public void TestInplaceGetSet(int length) |
|||
{ |
|||
var original = GetTestArray(length); |
|||
|
|||
for (var i = 0; i < original.Length; i++) |
|||
{ |
|||
var modified = original; |
|||
modified[i] = 0; |
|||
for (var j = 0; j < original.Length; j++) |
|||
{ |
|||
if (i == j) |
|||
{ |
|||
// This is the one we overwrote
|
|||
Assert.AreEqual(0, modified[j]); |
|||
} |
|||
else |
|||
{ |
|||
// Other elements should be unchanged
|
|||
Assert.AreEqual(original[j], modified[j]); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
[TestCaseSource(typeof(LengthCases))] |
|||
public void TestInvalidAccess(int length) |
|||
{ |
|||
var tmp = 0; |
|||
var a = GetTestArray(length); |
|||
// get
|
|||
Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[-1]; }); |
|||
Assert.Throws<IndexOutOfRangeException>(() => { tmp += a[length]; }); |
|||
|
|||
// set
|
|||
Assert.Throws<IndexOutOfRangeException>(() => { a[-1] = 0; }); |
|||
Assert.Throws<IndexOutOfRangeException>(() => { a[length] = 0; }); |
|||
|
|||
// Make sure temp is used
|
|||
Assert.AreEqual(0, tmp); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestOperatorEqualsDifferentLengths() |
|||
{ |
|||
// Check arrays of different length are never equal (even if they have 0s in all elements)
|
|||
for (var l1 = 1; l1 <= 4; l1++) |
|||
{ |
|||
var a1 = GetZeroArray(l1); |
|||
for (var l2 = 1; l2 <= 4; l2++) |
|||
{ |
|||
var a2 = GetZeroArray(l2); |
|||
if (l1 == l2) |
|||
{ |
|||
Assert.AreEqual(a1, a2); |
|||
Assert.IsTrue(a1 == a2); |
|||
} |
|||
else |
|||
{ |
|||
Assert.AreNotEqual(a1, a2); |
|||
Assert.IsTrue(a1 != a2); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
[TestCaseSource(typeof(LengthCases))] |
|||
public void TestOperatorEquals(int length) |
|||
{ |
|||
for (var index = 0; index < length; index++) |
|||
{ |
|||
var a1 = GetTestArray(length); |
|||
var a2 = GetTestArray(length); |
|||
Assert.AreEqual(a1, a2); |
|||
Assert.IsTrue(a1 == a2); |
|||
|
|||
a1[index] = 42; |
|||
Assert.AreNotEqual(a1, a2); |
|||
Assert.IsTrue(a1 != a2); |
|||
|
|||
a2[index] = 42; |
|||
Assert.AreEqual(a1, a2); |
|||
Assert.IsTrue(a1 == a2); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestToString() |
|||
{ |
|||
Assert.AreEqual("[1]", new InplaceArray<int>(1).ToString()); |
|||
Assert.AreEqual("[1, 2]", new InplaceArray<int>(1, 2).ToString()); |
|||
Assert.AreEqual("[1, 2, 3]", new InplaceArray<int>(1, 2, 3).ToString()); |
|||
Assert.AreEqual("[1, 2, 3, 4]", new InplaceArray<int>(1, 2, 3, 4).ToString()); |
|||
} |
|||
|
|||
[TestCaseSource(typeof(LengthCases))] |
|||
public void TestFromList(int length) |
|||
{ |
|||
var intArray = new int[length]; |
|||
for (var i = 0; i < length; i++) |
|||
{ |
|||
intArray[i] = (i + 1) * 11; // 11, 22, etc.
|
|||
} |
|||
|
|||
var converted = InplaceArray<int>.FromList(intArray); |
|||
Assert.AreEqual(GetTestArray(length), converted); |
|||
} |
|||
} |
|||
} |
|
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Tests |
|||
{ |
|||
[TestFixture] |
|||
public class ObservationSpecTests |
|||
{ |
|||
[Test] |
|||
public void TestVectorObsSpec() |
|||
{ |
|||
var obsSpec = ObservationSpec.Vector(5); |
|||
Assert.AreEqual(1, obsSpec.Rank); |
|||
|
|||
var shape = obsSpec.Shape; |
|||
Assert.AreEqual(1, shape.Length); |
|||
Assert.AreEqual(5, shape[0]); |
|||
|
|||
var dimensionProps = obsSpec.DimensionProperties; |
|||
Assert.AreEqual(1, dimensionProps.Length); |
|||
Assert.AreEqual(DimensionProperty.None, dimensionProps[0]); |
|||
|
|||
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestVariableLengthObsSpec() |
|||
{ |
|||
var obsSpec = ObservationSpec.VariableLength(5, 6); |
|||
Assert.AreEqual(2, obsSpec.Rank); |
|||
|
|||
var shape = obsSpec.Shape; |
|||
Assert.AreEqual(2, shape.Length); |
|||
Assert.AreEqual(5, shape[0]); |
|||
Assert.AreEqual(6, shape[1]); |
|||
|
|||
var dimensionProps = obsSpec.DimensionProperties; |
|||
Assert.AreEqual(2, dimensionProps.Length); |
|||
Assert.AreEqual(DimensionProperty.VariableSize, dimensionProps[0]); |
|||
Assert.AreEqual(DimensionProperty.None, dimensionProps[1]); |
|||
|
|||
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestVisualObsSpec() |
|||
{ |
|||
var obsSpec = ObservationSpec.Visual(5, 6, 7); |
|||
Assert.AreEqual(3, obsSpec.Rank); |
|||
|
|||
var shape = obsSpec.Shape; |
|||
Assert.AreEqual(3, shape.Length); |
|||
Assert.AreEqual(5, shape[0]); |
|||
Assert.AreEqual(6, shape[1]); |
|||
Assert.AreEqual(7, shape[2]); |
|||
|
|||
var dimensionProps = obsSpec.DimensionProperties; |
|||
Assert.AreEqual(3, dimensionProps.Length); |
|||
Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[0]); |
|||
Assert.AreEqual(DimensionProperty.TranslationalEquivariance, dimensionProps[1]); |
|||
Assert.AreEqual(DimensionProperty.None, dimensionProps[2]); |
|||
|
|||
Assert.AreEqual(ObservationType.Default, obsSpec.ObservationType); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestMismatchShapeDimensionPropThrows() |
|||
{ |
|||
var shape = new InplaceArray<int>(1, 2); |
|||
var dimProps = new InplaceArray<DimensionProperty>(DimensionProperty.TranslationalEquivariance); |
|||
Assert.Throws<UnityAgentsException>(() => |
|||
{ |
|||
new ObservationSpec(shape, dimProps); |
|||
}); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 27ff1979bd5e4b8ebeb4d98f414a5090 |
|||
timeCreated: 1615863866 |
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// The ObservationType enum of the Sensor.
|
|||
/// </summary>
|
|||
internal enum ObservationType |
|||
{ |
|||
// Collected observations are generic.
|
|||
Default = 0, |
|||
// Collected observations contain goal information.
|
|||
Goal = 1, |
|||
// Collected observations contain reward information.
|
|||
Reward = 2, |
|||
// Collected observations are messages from other agents.
|
|||
Message = 3, |
|||
} |
|||
|
|||
|
|||
/// <summary>
|
|||
/// Sensor interface for sensors with variable types.
|
|||
/// </summary>
|
|||
internal interface ITypedSensor |
|||
{ |
|||
/// <summary>
|
|||
/// Returns the ObservationType enum corresponding to the type of the sensor.
|
|||
/// </summary>
|
|||
/// <returns>The ObservationType enum</returns>
|
|||
ObservationType GetObservationType(); |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 3751edac8122c411dbaef8f1b7043b82 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// The Dimension property flags of the observations
|
|||
/// </summary>
|
|||
[System.Flags] |
|||
public enum DimensionProperty |
|||
{ |
|||
/// <summary>
|
|||
/// No properties specified.
|
|||
/// </summary>
|
|||
Unspecified = 0, |
|||
|
|||
/// <summary>
|
|||
/// No Property of the observation in that dimension. Observation can be processed with
|
|||
/// fully connected networks.
|
|||
/// </summary>
|
|||
None = 1, |
|||
|
|||
/// <summary>
|
|||
/// Means it is suitable to do a convolution in this dimension.
|
|||
/// </summary>
|
|||
TranslationalEquivariance = 2, |
|||
|
|||
/// <summary>
|
|||
/// Means that there can be a variable number of observations in this dimension.
|
|||
/// The observations are unordered.
|
|||
/// </summary>
|
|||
VariableSize = 4, |
|||
} |
|||
|
|||
|
|||
/// <summary>
|
|||
/// Sensor interface for sensors with special dimension properties.
|
|||
/// </summary>
|
|||
internal interface IDimensionPropertiesSensor |
|||
{ |
|||
/// <summary>
|
|||
/// Returns the array containing the properties of each dimensions of the
|
|||
/// observation. The length of the array must be equal to the rank of the
|
|||
/// observation tensor.
|
|||
/// </summary>
|
|||
/// <returns>The array of DimensionProperty</returns>
|
|||
DimensionProperty[] GetDimensionProperties(); |
|||
} |
|||
} |
撰写
预览
正在加载...
取消
保存
Reference in new issue