浏览代码

spec and inplacearray cleanup

/v2-staging-rebase
Chris Elion 3 年前
当前提交
1c508989
共有 24 个文件被更改,包括 103 次插入128 次删除
  1. 2
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
  2. 2
      Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs
  3. 4
      com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
  4. 4
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  5. 4
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  6. 34
      com.unity.ml-agents/Runtime/InplaceArray.cs
  7. 2
      com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
  8. 2
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  9. 2
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  10. 75
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
  11. 27
      com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
  12. 2
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
  13. 2
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  14. 2
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
  15. 11
      com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs
  16. 6
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  17. 2
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  18. 7
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  19. 2
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  20. 2
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  21. 4
      com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
  22. 2
      com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs
  23. 21
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  24. 10
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs

2
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs


/// <inheritdoc/>
public override ObservationSpec GetObservationSpec()
{
return ObservationSpec.FromShape(BasicController.k_Extents);
return ObservationSpec.Vector(BasicController.k_Extents);
}
/// <inheritdoc/>

2
Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensor.cs


var width = texture.width;
var height = texture.height;
m_Name = name;
m_ObservationSpec = ObservationSpec.FromShape(height, width, 3);
m_ObservationSpec = ObservationSpec.Visual(height, width, 3);
m_CompressionType = compressionType;
}

4
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs


m_ObservationType = obsType;
m_ObservationSpec = obsType == Match3ObservationType.Vector
? ObservationSpec.FromShape(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.FromShape(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);

4
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


// Default root reference to current game object
if (rootReference == null)
rootReference = gameObject;
m_ObservationSpec = ObservationSpec.FromShape(GridNumSideX, GridNumSideZ, ObservationPerCell);
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
compressedImgs = new List<byte[]>();
byteSizesBytesList = new List<byte[]>();

var shape = m_ObservationSpec.Shape;
if (shape[0] != GridNumSideX || shape[1] != GridNumSideZ || shape[2] != ObservationPerCell)
{
m_ObservationSpec = ObservationSpec.FromShape(GridNumSideX, GridNumSideZ, ObservationPerCell);
m_ObservationSpec = ObservationSpec.Visual(GridNumSideX, GridNumSideZ, ObservationPerCell);
}
return m_ObservationSpec;
}

4
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_ObservationSpec = ObservationSpec.FromShape(numTransformObservations + numJointExtractorObservations);
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}
#if UNITY_2020_1_OR_NEWER

}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_ObservationSpec = ObservationSpec.FromShape(numTransformObservations + numJointExtractorObservations);
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}
#endif

34
com.unity.ml-agents/Runtime/InplaceArray.cs


throw new ArgumentOutOfRangeException();
}
}
public static bool operator ==(InplaceArray<T> lhs, InplaceArray<T> rhs)
{
if (lhs.Length != rhs.Length)
{
return false;
}
for (var i = 0; i < lhs.Length; i++)
{
// See https://stackoverflow.com/a/390974/224264
if (!EqualityComparer<T>.Default.Equals(lhs[i], rhs[i]))
{
return false;
}
}
return true;
}
public static bool operator !=(InplaceArray<T> lhs, InplaceArray<T> rhs) => !(lhs == rhs);
public override bool Equals(object other) => other is InplaceArray<T> other1 && this.Equals(other1);
public bool Equals(InplaceArray<T> other)
{
return this == other;
}
public override int GetHashCode()
{
// TODO need to switch on length?
return Tuple.Create(m_elem0, m_elem1, m_elem2, m_elem3, Length).GetHashCode();
}
}
}

2
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs


m_ObsSize = obsSize;
m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs];
m_CurrentNumObservables = 0;
m_ObservationSpec = ObservationSpec.FromShape(m_MaxNumObs, m_ObsSize);
m_ObservationSpec = ObservationSpec.VariableSize(m_MaxNumObs, m_ObsSize);
}
/// <inheritdoc/>

2
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


m_Grayscale = grayscale;
m_Name = name;
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.FromShape(height, width, channels);
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
m_CompressionType = compression;
}

2
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


{
var obsSpec = sensor.GetObservationSpec();
var count = 1;
for (var i = 0; i < obsSpec.Shape.Length; i++)
for (var i = 0; i < obsSpec.NumDimensions; i++)
{
count *= obsSpec.Shape[i];
}

75
com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs


public InplaceArray<int> Shape;
public InplaceArray<DimensionProperty> DimensionProperties;
public int Dimensions
public int NumDimensions
// TODO RENAME?
public static ObservationSpec FromShape(int length)
public static ObservationSpec Vector(int length)
return new ObservationSpec
{
ObservationType = ObservationType.Default,
Shape = shape,
DimensionProperties = dimProps
};
return new ObservationSpec(shape, dimProps);
public static ObservationSpec FromShape(int obsSize, int maxNumObs)
public static ObservationSpec VariableSize(int obsSize, int maxNumObs)
return new ObservationSpec
{
ObservationType = ObservationType.Default,
Shape = shape,
DimensionProperties = dimProps
};
return new ObservationSpec(shape, dimProps);
public static ObservationSpec FromShape(int height, int width, int channels)
public static ObservationSpec Visual(int height, int width, int channels)
return new ObservationSpec
return new ObservationSpec(shape, dimProps);
}
internal ObservationSpec(
InplaceArray<int> shape,
InplaceArray<DimensionProperty> dimensionProperties,
ObservationType observationType = ObservationType.Default
)
{
if (shape.Length != dimensionProperties.Length)
ObservationType = ObservationType.Default,
Shape = shape,
DimensionProperties = dimProps
};
throw new UnityAgentsException("shape and dimensionProperties must have the same length.");
}
Shape = shape;
DimensionProperties = dimensionProperties;
ObservationType = observationType;
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// <summary>
/// Information about a single dimension. Future per-dimension properties can go here.
/// This is nicer because it ensures the shape and dimension properties that the same size
/// </summary>
public struct DimensionInfo
{
public int Rank;
public DimensionProperty DimensionProperty;
}
public struct ObservationSpecAlternativeOne
{
public ObservationType ObservationType;
public DimensionInfo[] DimensionInfos;
// Similar ObservationSpec.FromShape() as above
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// <summary>
/// Uses Barracuda's TensorShape struct instead of an int[] for the shape.
/// This doesn't fully avoid allocations because of DimensionProperty, so we'd need more supporting code.
/// I don't like explicitly depending on Barracuda in one of our central interfaces, but listing as an alternative.
/// </summary>
public struct ObservationSpecAlternativeTwo
{
public ObservationType ObservationType;
public TensorShape Shape;
public DimensionProperty[] DimensionProperties;
}
}

27
com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs


/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
/// <param name="shape">Shape of the observations to be written.</param>
/// <param name="offset">Offset from the start of the float data to write to.</param>
internal void SetTarget(IList<float> data, int[] shape, int offset)
{
m_Data = data;
m_Offset = offset;
m_Proxy = null;
m_Batch = 0;
if (shape.Length == 1)
{
m_TensorShape = new TensorShape(m_Batch, shape[0]);
}
else if (shape.Length == 2)
{
m_TensorShape = new TensorShape(new[] { m_Batch, 1, shape[0], shape[1] });
}
else
{
m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
}
}
/// <summary>
/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
/// <param name="observationSpec">ObservationSpec of the observation to be written</param>
/// <param name="offset">Offset from the start of the float data to write to.</param>
internal void SetTarget(IList<float> data, ObservationSpec observationSpec, int offset)

2
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


void SetNumObservations(int numObservations)
{
m_ObservationSpec = ObservationSpec.FromShape(numObservations);
m_ObservationSpec = ObservationSpec.Vector(numObservations);
m_Observations = new float[numObservations];
}

2
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


m_PropertyInfo = reflectionSensorInfo.PropertyInfo;
m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute;
m_SensorName = reflectionSensorInfo.SensorName;
m_ObservationSpec = ObservationSpec.FromShape(size);
m_ObservationSpec = ObservationSpec.Vector(size);
m_NumFloats = size;
}

2
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs


var height = renderTexture != null ? renderTexture.height : 0;
m_Grayscale = grayscale;
m_Name = name;
m_ObservationSpec = ObservationSpec.FromShape(height, width, grayscale ? 1 : 3);
m_ObservationSpec = ObservationSpec.Visual(height, width, grayscale ? 1 : 3);
m_CompressionType = compressionType;
}

11
com.unity.ml-agents/Runtime/Sensors/SensorShapeValidator.cs


{
var cachedSpec = m_SensorShapes[i];
var sensorSpec = sensors[i].GetObservationSpec();
Debug.Assert(cachedSpec.Shape.Length == sensorSpec.Shape.Length, "Sensor dimensions must match.");
for (var j = 0; j < Mathf.Min(cachedSpec.Shape.Length, sensorSpec.Shape.Length); j++)
{
Debug.Assert(cachedSpec.Shape[j] == sensorSpec.Shape[j], "Sensor sizes must match.");
}
Debug.AssertFormat(
cachedSpec.Shape == sensorSpec.Shape,
"Sensor shapes must match. {0} != {1}",
cachedSpec.Shape,
sensorSpec.Shape
);
}
}
}

6
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


m_UnstackedObservationSize = wrapped.ObservationSize();
// TODO support arbitrary stacking dimension
m_ObservationSpec.Shape[m_ObservationSpec.Shape.Length - 1] *= numStackedObservations;
m_ObservationSpec.Shape[m_ObservationSpec.NumDimensions - 1] *= numStackedObservations;
// Initialize uncompressed buffer anyway in case python trainer does not
// support the compression mapping and has to fall back to uncompressed obs.

m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
}
if (m_WrappedSpec.Shape.Length != 1)
if (m_WrappedSpec.NumDimensions != 1)
{
var wrappedShape = m_WrappedSpec.Shape;
m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]);

// Now write the saved observations (oldest first)
var numWritten = 0;
if (m_WrappedSpec.Shape.Length == 1)
if (m_WrappedSpec.NumDimensions == 1)
{
for (var i = 0; i < m_NumStackedObservations; i++)
{

2
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


m_Observations = new List<float>(observationSize);
m_Name = name;
m_ObservationSpec = ObservationSpec.FromShape(observationSize);
m_ObservationSpec = ObservationSpec.Vector(observationSize);
}
/// <inheritdoc/>

7
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants)
{
var inplaceShape = InplaceArray<int>.FromList(shape);
dummySensor.ObservationSpec = ObservationSpec.FromShape(shape[0]);
dummySensor.ObservationSpec = ObservationSpec.Vector(shape[0]);
dummySensor.ObservationSpec = ObservationSpec.FromShape(shape[0], shape[1], shape[2]);
dummySensor.ObservationSpec = ObservationSpec.Visual(shape[0], shape[1], shape[2]);
}
else
{

obsWriter.SetTarget(new float[128], shape, 0);
obsWriter.SetTarget(new float[128], inplaceShape, 0);
var caps = new UnityRLCapabilities
{

2
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


public ObservationSpec GetObservationSpec()
{
return ObservationSpec.FromShape(0);
return ObservationSpec.Vector(0);
}
public int Write(ObservationWriter writer)

2
com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs


public ObservationSpec GetObservationSpec()
{
return ObservationSpec.FromShape(m_Height, m_Width, m_Channels);
return ObservationSpec.Visual(m_Height, m_Width, m_Channels);
}
public int Write(ObservationWriter writer)

4
com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs


Height = height;
m_Name = name;
m_ObservationSpec = ObservationSpec.FromShape(height, width, 1);
m_ObservationSpec = ObservationSpec.Visual(height, width, 1);
floatData = new float[Height, Width];
}

Height = floatData.GetLength(0);
Width = floatData.GetLength(1);
m_Name = name;
m_ObservationSpec = ObservationSpec.FromShape(Height, Width, 1);
m_ObservationSpec = ObservationSpec.Visual(Height, Width, 1);
}
public string GetName()

2
com.unity.ml-agents/Tests/Editor/Sensor/ObservationWriterTests.cs


{
ObservationWriter writer = new ObservationWriter();
var buffer = new[] { 0f, 0f, 0f };
var shape = new[] { 3 };
var shape = new InplaceArray<int>(3);
writer.SetTarget(buffer, shape, 0);
// Elementwise writes

21
com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs


using System.Collections.Generic;
using System.Text.RegularExpressions;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;

public DummySensor(int dim1)
{
m_ObservationSpec = ObservationSpec.FromShape(dim1);
m_ObservationSpec = ObservationSpec.Vector(dim1);
m_ObservationSpec = ObservationSpec.FromShape(dim1, dim2);
m_ObservationSpec = ObservationSpec.VariableSize(dim1, dim2);
m_ObservationSpec = ObservationSpec.FromShape(dim1, dim2, dim3);
m_ObservationSpec = ObservationSpec.Visual(dim1, dim2, dim3);
}
public string GetName()

validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5) };
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList1);
}

validator.ValidateSensors(sensorList1);
var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(2, 3), new DummySensor(4, 5, 7) };
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList1);
}

var sensorList2 = new List<ISensor>() { new DummySensor(1), new DummySensor(9) };
LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList2);
// Add the sensors in the other order

LogAssert.Expect(LogType.Assert, "Sensor dimensions must match.");
LogAssert.Expect(LogType.Assert, "Sensor sizes must match.");
LogAssert.Expect(LogType.Assert, new Regex("Sensor shapes must match.*"));
validator.ValidateSensors(sensorList1);
}
}

10
com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs


// Test mapping with number of layers not being multiple of 3
var dummySensor = new Dummy3DSensor();
dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4);
dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });

paddedDummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4);
paddedDummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 };
var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2);
Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });

public void Test3DStacking()
{
var wrapped = new Dummy3DSensor();
wrapped.ObservationSpec = ObservationSpec.FromShape(2, 1, 2);
wrapped.ObservationSpec = ObservationSpec.Visual(2, 1, 2);
var sensor = new StackingSensor(wrapped, 2);
// Check the stacking is on the last dimension

public void TestStackedGetCompressedObservation()
{
var wrapped = new Dummy3DSensor();
wrapped.ObservationSpec = ObservationSpec.FromShape(1, 1, 3);
wrapped.ObservationSpec = ObservationSpec.Visual(1, 1, 3);
var sensor = new StackingSensor(wrapped, 2);
wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } };

public void TestStackingSensorBuiltInSensorType()
{
var dummySensor = new Dummy3DSensor();
dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4);
dummySensor.ObservationSpec = ObservationSpec.Visual(2, 2, 4);
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
var stackedDummySensor = new StackingSensor(dummySensor, 2);
Assert.AreEqual(stackedDummySensor.GetBuiltInSensorType(), BuiltInSensorType.Unknown);

正在加载...
取消
保存