浏览代码

use ObservationSpec everywhere

/v2-staging-rebase
Chris Elion 3 年前
当前提交
8f16fab3
共有 29 个文件被更改,包括 146 次插入125 次删除
  1. 4
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
  2. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
  3. 8
      Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
  4. 12
      com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
  5. 29
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  6. 10
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  7. 12
      com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
  8. 8
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs
  9. 6
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs
  10. 2
      com.unity.ml-agents/Runtime/Analytics/Events.cs
  11. 4
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  12. 20
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  13. 2
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  14. 2
      com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
  15. 4
      com.unity.ml-agents/Runtime/SensorHelper.cs
  16. 7
      com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
  17. 12
      com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
  18. 12
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  19. 16
      com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
  20. 8
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  21. 4
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  22. 6
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  23. 2
      com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs
  24. 2
      com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs
  25. 13
      com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
  26. 20
      com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs
  27. 2
      com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs
  28. 12
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  29. 30
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs

4
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs


}
/// <inheritdoc/>
public override int[] GetObservationShape()
public override ObservationSpec GetObservationSpec()
return new[] { BasicController.k_Extents };
return ObservationSpec.FromShape(BasicController.k_Extents);
}
/// <inheritdoc/>

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs


public abstract void WriteObservation(float[] output);
/// <inheritdoc/>
public abstract int[] GetObservationShape();
public abstract ObservationSpec GetObservationSpec();
/// <inheritdoc/>
public abstract string GetName();

8
Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs


{
Texture2D m_Texture;
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
SensorCompressionType m_CompressionType;
/// <summary>

var width = texture.width;
var height = texture.height;
m_Name = name;
m_Shape = new[] { height, width, 3 };
m_ObservationSpec = ObservationSpec.FromShape(height, width, 3);
m_CompressionType = compressionType;
}

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs


{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
private int[] m_SparseChannelMapping;
private string m_Name;

m_NumSpecialTypes = board.NumSpecialTypes;
m_ObservationType = obsType;
m_Shape = obsType == Match3ObservationType.Vector ?
new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
m_ObservationSpec = obsType == Match3ObservationType.Vector
? ObservationSpec.FromShape(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.FromShape(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);

}
/// <inheritdoc/>
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

29
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


protected bool Initialized = false;
/// <summary>
/// Array holding the dimensions of the resulting tensor
/// Cached ObservationSpec
private int[] m_Shape;
private ObservationSpec m_ObservationSpec;
//
// Debug Parameters

// Default root reference to current game object
if (rootReference == null)
rootReference = gameObject;
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
m_ObservationSpec = ObservationSpec.FromShape(GridNumSideX, GridNumSideZ, ObservationPerCell);
compressedImgs = new List<byte[]>();
byteSizesBytesList = new List<byte[]>();

CellActivity[i] = DebugDefaultColor;
}
}
}
/// <summary>Gets the shape of the grid observation</summary>
/// <returns>integer array shape of the grid observation</returns>
public int[] GetFloatObservationShape()
{
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
}
/// <inheritdoc/>

/// <summary>Gets the observation shape</summary>
/// <returns>int[] of the observation shape</returns>
public ObservationSpec GetObservationSpec()
{
// Lazy update
var shape = m_ObservationSpec.Shape;
if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell)
{
m_ObservationSpec = ObservationSpec.FromShape(GridNumSideX, GridNumSideZ, ObservationPerCell);
}
return m_ObservationSpec;
}
/// <inheritdoc/>
m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
return m_Shape;
return m_ObservationSpec.Shape;
}
/// <inheritdoc/>

10
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


/// </summary>
public class PhysicsBodySensor : ISensor, IBuiltInSensor
{
int[] m_Shape;
ObservationSpec m_ObservationSpec;
string m_SensorName;
PoseExtractor m_PoseExtractor;

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.FromShape(numTransformObservations + numJointExtractorObservations);
}
#if UNITY_2020_1_OR_NEWER

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_ObservationSpec = ObservationSpec.FromShape(numTransformObservations + numJointExtractorObservations);
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs


var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

8
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

gridSensor.Start();
int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}

6
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs


gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
[Test]

gridSensor.Start();
int[] expectedShape = { 10, 10, 7 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetFloatObservationShape());
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
}
}
}

2
com.unity.ml-agents/Runtime/Analytics/Events.cs


public static EventObservationSpec FromSensor(ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var dimProps = (sensor as IDimensionPropertiesSensor)?.GetDimensionProperties();
var dimInfos = new EventObservationDimensionInfo[shape.Length];
for (var i = 0; i < shape.Length; i++)

4
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


/// <returns></returns>
public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
ObservationProto observationProto = null;
var compressionType = sensor.GetCompressionType();
// Check capabilities if we need to concatenate PNGs

floatDataProto.Data.Add(0.0f);
}
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0);
sensor.Write(observationWriter);
observationProto = new ObservationProto

20
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sensor = sensors[sensorIndex];
if (sensor.GetObservationShape().Length == 3)
if (sensor.GetObservationSpec().Shape.Length == 3)
{
if (!tensorsNames.Contains(
TensorNames.GetVisualObservationName(visObsIndex)))

}
visObsIndex++;
}
if (sensor.GetObservationShape().Length == 2)
if (sensor.GetObservationSpec().Shape.Length == 2)
{
if (!tensorsNames.Contains(
TensorNames.GetObservationName(sensorIndex)))

static string CheckVisualObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var heightBp = shape[0];
var widthBp = shape[1];
var pixelBp = shape[2];

static string CheckRankTwoObsShape(
TensorProxy tensorProxy, ISensor sensor)
{
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var dim1Bp = shape[0];
var dim2Bp = shape[1];
var dim1T = tensorProxy.Channels;

for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationShape().Length == 3)
if (sens.GetObservationSpec().Shape.Length == 3)
{
tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] =

if (sens.GetObservationShape().Length == 2)
if (sens.GetObservationSpec().Shape.Length == 2)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);

var totalVectorSensorSize = 0;
foreach (var sens in sensors)
{
if ((sens.GetObservationShape().Length == 1))
if ((sens.GetObservationSpec().Shape.Length == 1))
totalVectorSensorSize += sens.GetObservationShape()[0];
totalVectorSensorSize += sens.GetObservationSpec().Shape[0];
}
}

foreach (var sensorComp in sensors)
{
if (sensorComp.GetObservationShape().Length == 1)
if (sensorComp.GetObservationSpec().Shape.Length == 1)
var vecSize = sensorComp.GetObservationShape()[0];
var vecSize = sensorComp.GetObservationSpec().Shape[0];
if (sensorSizes.Length == 0)
{
sensorSizes = $"[{vecSize}";

2
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var sensor = sensors[sensorIndex];
var shape = sensor.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
var rank = shape.Length;
ObservationGenerator obsGen = null;
string obsGenName = null;

2
com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs


{
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationShape(), 0);
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationSpec(), 0);
sensor.Write(m_ObservationWriter);
}
else

4
com.unity.ml-agents/Runtime/SensorHelper.cs


}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)

7
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs


private int m_ObsSize;
float[] m_ObservationBuffer;
int m_CurrentNumObservables;
ObservationSpec m_ObservationSpec;
static DimensionProperty[] s_DimensionProperties = new DimensionProperty[]{
DimensionProperty.VariableSize,
DimensionProperty.None

m_ObsSize = obsSize;
m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs];
m_CurrentNumObservables = 0;
m_ObservationSpec = ObservationSpec.FromShape(m_MaxNumObs, m_ObsSize);
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new int[] { m_MaxNumObs, m_ObsSize };
return m_ObservationSpec;
}
/// <inheritdoc/>

12
com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs


}
/// <summary>
/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
/// <param name="observationSpec">ObservationSpec of the observation to be written</param>
/// <param name="offset">Offset from the start of the float data to write to.</param>
internal void SetTarget(IList<float> data, ObservationSpec observationSpec, int offset)
{
// TODO remove int[] version
SetTarget(data, observationSpec.Shape, offset);
}
/// <summary>
/// Set the writer to write to a TensorProxy at the given batch and channel offset.
/// </summary>
/// <param name="tensorProxy">Tensor proxy that will be written to.</param>

12
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


// Cached sensor names and shapes.
string m_SensorName;
int[] m_Shape;
ObservationSpec m_ObservationSpec;
int m_NumFloats;
public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size)
{

m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute;
m_SensorName = reflectionSensorInfo.SensorName;
m_Shape = new[] { size };
m_ObservationSpec = ObservationSpec.FromShape(size);
m_NumFloats = size;
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
/// <inheritdoc/>

return m_Shape[0];
return m_NumFloats;
}
internal abstract void WriteReflectedField(ObservationWriter writer);

16
com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs


{
internal class SensorShapeValidator
{
List<int[]> m_SensorShapes;
List<ObservationSpec> m_SensorShapes;
/// <summary>
/// Check that the List Sensors are the same shape as the previous ones.

{
if (m_SensorShapes == null)
{
m_SensorShapes = new List<int[]>(sensors.Count);
m_SensorShapes = new List<ObservationSpec>(sensors.Count);
m_SensorShapes.Add(sensor.GetObservationShape());
m_SensorShapes.Add(sensor.GetObservationSpec());
}
}
else

);
for (var i = 0; i < Mathf.Min(m_SensorShapes.Count, sensors.Count); i++)
{
var cachedShape = m_SensorShapes[i];
var sensorShape = sensors[i].GetObservationShape();
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
for (var j = 0; j < Mathf.Min(cachedShape.Length, sensorShape.Length); j++)
var cachedSpec = m_SensorShapes[i];
var sensorSpec = sensors[i].GetObservationSpec();
Debug.Assert(cachedSpec.Shape.Length == sensorSpec.Shape.Length, "Sensor dimensions must match.");
for (var j = 0; j < Mathf.Min(cachedSpec.Shape.Length, sensorSpec.Shape.Length); j++)
Debug.Assert(cachedShape[j] == sensorShape[j], "Sensor sizes must match.");
Debug.Assert(cachedSpec.Shape[j] == sensorSpec.Shape[j], "Sensor sizes must match.");
}
}
}

8
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


class DummySensor : ISensor
{
public int[] Shape;
public ObservationSpec ObservationSpec;
public SensorCompressionType CompressionType;
internal DummySensor()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return Shape;
return ObservationSpec;
}
public int Write(ObservationWriter writer)

var dummySensor = new DummySensor();
var obsWriter = new ObservationWriter();
dummySensor.Shape = shape;
dummySensor.ObservationSpec = ObservationSpec.FromShape(shape);
dummySensor.CompressionType = compressionType;
obsWriter.SetTarget(new float[128], shape, 0);

4
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


sensorName = n;
}
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new[] { 0 };
return ObservationSpec.FromShape(0);
}
public int Write(ObservationWriter writer)

6
com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs


public override int[] GetObservationShape()
{
return Sensor.GetObservationShape();
return Sensor.GetObservationSpec().Shape;
}
}
public class Test3DSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor

m_Name = name;
}
public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return new[] { m_Height, m_Width, m_Channels };
return ObservationSpec.FromShape(m_Height, m_Width, m_Channels);
}
public int Write(ObservationWriter writer)

2
com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs


{
var bufferSensor = new BufferSensor(20, 4, "testName");
var shape = bufferSensor.GetObservationShape();
var shape = bufferSensor.GetObservationSpec().Shape;
var dimProp = bufferSensor.GetDimensionProperties();
Assert.AreEqual(shape[0], 20);
Assert.AreEqual(shape[1], 4);

2
com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs


Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape());
var sensor = cameraComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(CameraSensor), sensor.GetType());
}
}

13
com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs


public int Width { get; }
public int Height { get; }
string m_Name;
int[] m_Shape;
private ObservationSpec m_ObservationSpec;
public float[,] floatData;
public Float2DSensor(int width, int height, string name)

m_Name = name;
m_Shape = new[] { height, width, 1 };
m_ObservationSpec = ObservationSpec.FromShape(height, width, 1);
floatData = new float[Height, Width];
}

Height = floatData.GetLength(0);
Width = floatData.GetLength(1);
m_Name = name;
m_Shape = new[] { Height, Width, 1 };
m_ObservationSpec = ObservationSpec.FromShape(Height, Width, 1);
}
public string GetName()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
public byte[] GetCompressedObservation()

var output = new float[12];
var writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
writer.SetTarget(output, sensor.GetObservationSpec(), 0);
sensor.Write(writer);
for (var i = 0; i < 9; i++)
{

20
com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs


var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
writer.SetTarget(outputBuffer, sensor.GetObservationSpec(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

2
com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs


Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape());
var sensor = renderTexComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType());
}
}

12
com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs


public class DummySensor : ISensor
{
string m_Name = "DummySensor";
int[] m_Shape;
ObservationSpec m_ObservationSpec;
m_Shape = new[] { dim1 };
m_ObservationSpec = ObservationSpec.FromShape(dim1);
m_Shape = new[] { dim1, dim2, };
m_ObservationSpec = ObservationSpec.FromShape(dim1, dim2);
m_Shape = new[] { dim1, dim2, dim3 };
m_ObservationSpec = ObservationSpec.FromShape(dim1, dim2, dim3);
}
public string GetName()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return m_Shape;
return m_ObservationSpec;
}
public byte[] GetCompressedObservation()

30
com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs


ISensor wrapped = new VectorSensor(4);
ISensor sensor = new StackingSensor(wrapped, 4);
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
Assert.AreEqual(sensor.GetObservationShape(), new[] { 16 });
Assert.AreEqual(sensor.GetObservationSpec().Shape, new[] { 16 });
}
[Test]

{
public SensorCompressionType CompressionType = SensorCompressionType.PNG;
public int[] Mapping;
public int[] Shape;
public ObservationSpec ObservationSpec;
public float[,,] CurrentObservation;
internal Dummy3DSensor()

public int[] GetObservationShape()
public ObservationSpec GetObservationSpec()
return Shape;
return ObservationSpec;
for (var h = 0; h < Shape[0]; h++)
for (var h = 0; h < ObservationSpec.Shape[0]; h++)
for (var w = 0; w < Shape[1]; w++)
for (var w = 0; w < ObservationSpec.Shape[1]; w++)
for (var c = 0; c < Shape[2]; c++)
for (var c = 0; c < ObservationSpec.Shape[2]; c++)
return Shape[0] * Shape[1] * Shape[2];
return ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2];
var flattenedObservation = new float[Shape[0] * Shape[1] * Shape[2]];
writer.SetTarget(flattenedObservation, Shape, 0);
var flattenedObservation = new float[ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2]];
writer.SetTarget(flattenedObservation, ObservationSpec.Shape, 0);
Write(writer);
byte[] bytes = Array.ConvertAll(flattenedObservation, (z) => (byte)z);
return bytes;

// Test mapping with number of layers not being multiple of 3
var dummySensor = new Dummy3DSensor();
dummySensor.Shape = new[] { 2, 2, 4 };
dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4);
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });

paddedDummySensor.Shape = new[] { 2, 2, 4 };
paddedDummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4);
paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 };
var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2);
Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });

public void Test3DStacking()
{
var wrapped = new Dummy3DSensor();
wrapped.Shape = new[] { 2, 1, 2 };
wrapped.ObservationSpec = ObservationSpec.FromShape(2, 1, 2);
var sensor = new StackingSensor(wrapped, 2);
// Check the stacking is on the last dimension

public void TestStackedGetCompressedObservation()
{
var wrapped = new Dummy3DSensor();
wrapped.Shape = new[] { 1, 1, 3 };
wrapped.ObservationSpec = ObservationSpec.FromShape(1, 1, 3);
var sensor = new StackingSensor(wrapped, 2);
wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } };

public void TestStackingSensorBuiltInSensorType()
{
var dummySensor = new Dummy3DSensor();
dummySensor.Shape = new[] { 2, 2, 4 };
dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4);
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetBuiltInSensorType(), BuiltInSensorType.Unknown);

正在加载...
取消
保存