浏览代码

InplaceArray for shape

/v2-staging-rebase
Chris Elion 4 年前
当前提交
b0e1cfc9
共有 16 个文件被更改,包括 244 次插入57 次删除
  1. 3
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  2. 12
      com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
  3. 7
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  4. 3
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  5. 4
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  6. 67
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
  7. 27
      com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
  8. 4
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  9. 14
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs
  10. 3
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  11. 2
      com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs
  12. 3
      com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs
  13. 2
      com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs
  14. 2
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs
  15. 145
      com.unity.ml-agents/Runtime/InplaceArray.cs
  16. 3
      com.unity.ml-agents/Runtime/InplaceArray.cs.meta

3
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


/// <inheritdoc/>
public override int[] GetObservationShape()
{
return m_ObservationSpec.Shape;
var shape = m_ObservationSpec.Shape;
return new int[] { shape[0], shape[1], shape[2] };
}
/// <inheritdoc/>

12
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs


var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());

7
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


}
}
}
observationProto.Shape.AddRange(shape);
// Implement IEnumerable or IList?
for (var i = 0; i < shape.Length; i++)
{
observationProto.Shape.Add(shape[i]);
}
// Add the observation type, if any, to the observationProto
var typeSensor = sensor as ITypedSensor;

3
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


m_Height = height;
m_Grayscale = grayscale;
m_Name = name;
m_ObservationSpec = ObservationSpec.FromShape(GenerateShape(width, height, grayscale));
var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.FromShape(height, width, channels);
m_CompressionType = compression;
}

4
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


{
var obsSpec = sensor.GetObservationSpec();
var count = 1;
foreach (var dim in obsSpec.Shape)
for (var i = 0; i < obsSpec.Shape.Length; i++)
count *= dim;
count *= obsSpec.Shape[i];
}
return count;

67
com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs


public struct ObservationSpec
{
public ObservationType ObservationType;
public int[] Shape;
public DimensionProperty[] DimensionProperties;
public InplaceArray<int> Shape;
public InplaceArray<DimensionProperty> DimensionProperties;
/// <summary>
/// Create an Observation spec with default DimensionProperties and ObservationType from the shape.
/// </summary>
/// <param name="shape"></param>
/// <returns></returns>
public static ObservationSpec FromShape(params int[] shape)
public int Dimensions
DimensionProperty[] dimProps = null;
if (shape.Length == 1)
{
dimProps = new[] { DimensionProperty.None };
}
else if (shape.Length == 2)
get { return Shape.Length; }
}
// TODO RENAME?
public static ObservationSpec FromShape(int length)
{
InplaceArray<int> shape = new InplaceArray<int>(length);
InplaceArray<DimensionProperty> dimProps = new InplaceArray<DimensionProperty>(DimensionProperty.None);
return new ObservationSpec
// NOTE: not sure if I like this - might leave Unspecified and make BufferSensor set it
dimProps = new[] { DimensionProperty.VariableSize, DimensionProperty.None };
}
else if (shape.Length == 3)
{
dimProps = new[]
{
DimensionProperty.TranslationalEquivariance,
DimensionProperty.TranslationalEquivariance,
DimensionProperty.None
};
}
else
{
dimProps = new DimensionProperty[shape.Length];
for (var i = 0; i < dimProps.Length; i++)
{
dimProps[i] = DimensionProperty.Unspecified;
}
}
ObservationType = ObservationType.Default,
Shape = shape,
DimensionProperties = dimProps
};
}
public static ObservationSpec FromShape(int obsSize, int maxNumObs)
{
InplaceArray<int> shape = new InplaceArray<int>(obsSize, maxNumObs);
InplaceArray<DimensionProperty> dimProps = new InplaceArray<DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None);
return new ObservationSpec
{
ObservationType = ObservationType.Default,

}
public ObservationSpec Clone()
public static ObservationSpec FromShape(int height, int width, int channels)
InplaceArray<int> shape = new InplaceArray<int>(height, width, channels);
InplaceArray<DimensionProperty> dimProps = new InplaceArray<DimensionProperty>(
DimensionProperty.TranslationalEquivariance, DimensionProperty.TranslationalEquivariance, DimensionProperty.None
);
Shape = (int[])Shape.Clone(),
DimensionProperties = (DimensionProperty[])DimensionProperties.Clone(),
ObservationType = ObservationType
ObservationType = ObservationType.Default,
Shape = shape,
DimensionProperties = dimProps
};
}
}

27
com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs


}
/// <summary>
/// Set the writer to write to an IList at the given channelOffset.
/// </summary>
/// <param name="data">Float array or list that will be written to.</param>
/// <param name="shape">Shape of the observations to be written.</param>
/// <param name="offset">Offset from the start of the float data to write to.</param>
internal void SetTarget(IList<float> data, InplaceArray<int> shape, int offset)
{
m_Data = data;
m_Offset = offset;
m_Proxy = null;
m_Batch = 0;
if (shape.Length == 1)
{
m_TensorShape = new TensorShape(m_Batch, shape[0]);
}
else if (shape.Length == 2)
{
m_TensorShape = new TensorShape(new[] { m_Batch, 1, shape[0], shape[1] });
}
else
{
m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
}
}
/// <summary>
/// Set the writer to write to a TensorProxy at the given batch and channel offset.
/// </summary>
/// <param name="tensorProxy">Tensor proxy that will be written to.</param>

4
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";
m_WrappedSpec = wrapped.GetObservationSpec();
m_ObservationSpec = m_WrappedSpec.Clone();
m_ObservationSpec = m_WrappedSpec;
m_UnstackedObservationSize = wrapped.ObservationSize();

public int Write(ObservationWriter writer)
{
// First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one.
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec.Shape, 0);
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], m_WrappedSpec, 0);
m_WrappedSensor.Write(m_LocalWriter);
// Now write the saved observations (oldest first)

14
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs


using System;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Analytics;

var dummySensor = new DummySensor();
var obsWriter = new ObservationWriter();
dummySensor.ObservationSpec = ObservationSpec.FromShape(shape);
if (shape.Length == 1)
{
dummySensor.ObservationSpec = ObservationSpec.FromShape(shape[0]);
}
else if (shape.Length == 3)
{
dummySensor.ObservationSpec = ObservationSpec.FromShape(shape[0], shape[1], shape[2]);
}
else
{
throw new ArgumentOutOfRangeException();
}
dummySensor.CompressionType = compressionType;
obsWriter.SetTarget(new float[128], shape, 0);

3
com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs


public override int[] GetObservationShape()
{
return Sensor.GetObservationSpec().Shape;
var shape = Sensor.GetObservationSpec().Shape;
return new int[] { shape[0], shape[1], shape[2] };
}
}
public class Test3DSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor

2
com.unity.ml-agents/Tests/Editor/Sensor/BufferSensorTest.cs


var obsWriter = new ObservationWriter();
var obs = bufferSensor.GetObservationProto(obsWriter);
Assert.AreEqual(shape, obs.Shape);
Assert.AreEqual(shape, InplaceArray<int>.FromList(obs.Shape));
Assert.AreEqual(obs.DimensionProperties.Count, 2);
Assert.AreEqual((int)dimProp[0], obs.DimensionProperties[0]);
Assert.AreEqual((int)dimProp[1], obs.DimensionProperties[1]);

3
com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs


Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape());
var sensor = cameraComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
var expectedShapeInplace = new InplaceArray<int>(height, width, grayscale ? 1 : 3);
Assert.AreEqual(expectedShapeInplace, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(CameraSensor), sensor.GetType());
}
}

2
com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs


Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape());
var sensor = renderTexComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType());
}
}

2
com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs


ISensor wrapped = new VectorSensor(4);
ISensor sensor = new StackingSensor(wrapped, 4);
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
Assert.AreEqual(sensor.GetObservationSpec().Shape, new[] { 16 });
Assert.AreEqual(sensor.GetObservationSpec().Shape, new InplaceArray<int>(16));
}
[Test]

145
com.unity.ml-agents/Runtime/InplaceArray.cs


using System;
using System.Collections.Generic;
using System.Linq.Expressions;
namespace Unity.MLAgents
{
public struct InplaceArray<T> where T : struct
{
private const int k_MaxLength = 4;
private int m_Length;
private T m_elem0;
private T m_elem1;
private T m_elem2;
private T m_elem3;
public InplaceArray(T elem0)
{
m_Length = 1;
m_elem0 = elem0;
m_elem1 = new T { };
m_elem2 = new T { };
m_elem3 = new T { };
}
public InplaceArray(T elem0, T elem1)
{
m_Length = 2;
m_elem0 = elem0;
m_elem1 = elem1;
m_elem2 = new T { };
m_elem3 = new T { };
}
public InplaceArray(T elem0, T elem1, T elem2)
{
m_Length = 3;
m_elem0 = elem0;
m_elem1 = elem1;
m_elem2 = elem2;
m_elem3 = new T { };
}
public InplaceArray(T elem0, T elem1, T elem2, T elem3)
{
m_Length = 4;
m_elem0 = elem0;
m_elem1 = elem1;
m_elem2 = elem2;
m_elem3 = elem3;
}
public static InplaceArray<T> FromList(IList<T> elems)
{
switch (elems.Count)
{
case 1:
return new InplaceArray<T>(elems[0]);
case 2:
return new InplaceArray<T>(elems[0], elems[1]);
case 3:
return new InplaceArray<T>(elems[0], elems[1], elems[2]);
case 4:
return new InplaceArray<T>(elems[0], elems[1], elems[2], elems[3]);
default:
throw new ArgumentOutOfRangeException();
}
}
public T this[int index]
{
get
{
if (index < 0 || index >= k_MaxLength)
{
throw new ArgumentOutOfRangeException();
}
switch (index)
{
case 0:
return m_elem0;
case 1:
return m_elem1;
case 2:
return m_elem2;
case 3:
return m_elem3;
default:
throw new ArgumentOutOfRangeException();
}
}
internal set
{
if (index < 0 || index >= k_MaxLength)
{
throw new ArgumentOutOfRangeException();
}
switch (index)
{
case 0:
m_elem0 = value;
break;
case 1:
m_elem1 = value;
break;
case 2:
m_elem2 = value;
break;
case 3:
m_elem3 = value;
break;
default:
throw new ArgumentOutOfRangeException();
}
}
}
public int Length
{
get => m_Length;
}
public override string ToString()
{
switch (m_Length)
{
case 0:
return "[]";
case 1:
return $"[{m_elem0}]";
case 2:
return $"[{m_elem0}, {m_elem1}]";
case 3:
return $"[{m_elem0}, {m_elem1}, {m_elem2}]";
case 4:
return $"[{m_elem0}, {m_elem1}, {m_elem2}, {m_elem3}]";
default:
throw new ArgumentOutOfRangeException();
}
}
}
}

3
com.unity.ml-agents/Runtime/InplaceArray.cs.meta


fileFormatVersion: 2
guid: c1a80abee18a41c8aee89aeb33f5985d
timeCreated: 1615506199
正在加载...
取消
保存