浏览代码

[MLA-1634] Remove SensorComponent.GetObservationShape() (#5172)

/check-for-ModelOverriders
GitHub 3 年前
当前提交
354c37ca
共有 36 个文件被更改,包括 54 次插入257 次删除
  1. 6
      Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs
  2. 4
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
  3. 16
      Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs
  4. 14
      com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
  5. 20
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
  6. 7
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  7. 6
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  8. 20
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
  9. 35
      com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs
  10. 19
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs
  11. 15
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs
  12. 27
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs
  13. 6
      com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs
  14. 6
      com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs
  15. 1
      com.unity.ml-agents/CHANGELOG.md
  16. 2
      com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
  17. 2
      com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
  18. 6
      com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
  19. 1
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  20. 15
      com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
  21. 14
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs
  22. 1
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  23. 1
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
  24. 16
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs
  25. 7
      com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs
  26. 1
      com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs
  27. 7
      com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs
  28. 1
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  29. 11
      com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
  30. 4
      com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs
  31. 7
      com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs
  32. 5
      com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs
  33. 2
      com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs
  34. 1
      com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs
  35. 1
      com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs
  36. 4
      docs/Migrating.md

6
Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs


{
return new BasicSensor(basicController);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
return new[] { BasicController.k_Extents };
}
}
/// <summary>

4
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs


{
/// <summary>
/// Write the observations to the output buffer. This size of the buffer will be product
/// of the sizes returned by <see cref="GetObservationShape"/>.
/// of the Shape array values returned by <see cref="ObservationSpec"/>.
/// </summary>
/// <param name="output"></param>
public abstract void WriteObservation(float[] output);

/// <returns>The number of elements written.</returns>
public virtual int Write(ObservationWriter writer)
{
// TODO reuse buffer for similar agents, don't call GetObservationShape()
// TODO reuse buffer for similar agents
var numFloats = this.ObservationSize();
float[] buffer = new float[numFloats];
WriteObservation(buffer);

16
Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs


}
return m_Sensor;
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
var width = TestTexture.width;
var height = TestTexture.height;
var observationShape = new[] { height, width, 3 };
var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
if (stacks > 1)
{
observationShape[2] *= stacks;
}
return observationShape;
}
}

14
com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs


return new Match3Sensor(board, ObservationType, SensorName);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
var board = GetComponent<AbstractBoard>();
if (board == null)
{
return System.Array.Empty<int>();
}
var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1;
return ObservationType == Match3ObservationType.Vector ?
new[] { board.Rows * board.Columns * (board.NumCellTypes + specialSize) } :
new[] { board.Rows, board.Columns, board.NumCellTypes + specialSize };
}
}
}

20
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs


return new PhysicsBodySensor(RootBody, Settings, sensorName);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
{
numJointObservations += ArticulationBodyJointExtractor.NumObservations(articBody, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}
}
}

7
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
var shape = m_ObservationSpec.Shape;
return new int[] { shape[0], shape[1], shape[2] };
}
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
using (TimerStack.Instance.Scoped("GridSensor.WriteToTensor"))

6
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


}
#if UNITY_2020_1_OR_NEWER
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
{
var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
m_PoseExtractor = poseExtractor;

var numJointExtractorObservations = 0;
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
foreach (var articBody in poseExtractor.GetEnabledArticulationBodies())
{
var jointExtractor = new ArticulationBodyJointExtractor(articBody);
numJointExtractorObservations += jointExtractor.NumObservations(settings);

var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}
#endif
/// <inheritdoc/>

{
return BuiltInSensorType.PhysicsBodySensor;
}
}
}

20
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs


return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}
var poseExtractor = GetPoseExtractor();
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
foreach (var rb in poseExtractor.GetEnabledRigidbodies())
{
var joint = rb.GetComponent<Joint>();
numJointObservations += RigidBodyJointExtractor.NumObservations(rb, joint, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}
/// <summary>
/// Get the DisplayNodes of the hierarchy.
/// </summary>

35
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs


sensorComponent.ObservationType = Match3ObservationType.Vector;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3 * 3 * 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

sensorComponent.ObservationType = Match3ObservationType.Vector;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3 * 3 * (2 + 3));
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
var expectedObs = new float[]
{

};
SensorTestHelper.CompareObservation(sensor, expectedObs);
}
[Test]
public void TestVisualObservations()

sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);

sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2 + 3);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);

sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);

Assert.AreEqual(expectedPng, pngData);
}
[Test]
public void TestCompressedVisualObservationsSpecial()
{

sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2 + 3);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);

}
var expectedPng = LoadPNGs(pathPrefix, 2);
Assert.AreEqual(expectedPng, concatenatedPngData);
}
/// <summary>

}
return bytesOut.ToArray();
}
}
}

19
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelHotShapeTests.cs


1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 1);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}

1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 2);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}
[Test]

1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();
int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 3);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}

1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();
int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 6);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}
}

15
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ChannelShapeTests.cs


1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();
int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 1);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}
[Test]

1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();
int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 2);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}
[Test]

1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();
int[] expectedShape = { 10, 10, 7 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 7);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}
}
}

27
com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs


return duplicated;
}
/// <summary>
/// Asserts that 2 int arrays are the same
/// </summary>
/// <param name="expected">The expected array</param>
/// <param name="actual">The actual array</param>
public static void AssertArraysAreEqual(int[] expected, int[] actual)
{
Assert.AreEqual(expected.Length, actual.Length, "Lengths are not the same");
for (int i = 0; i < actual.Length; i++)
{
Assert.AreEqual(expected[i], actual[i], "Got " + Array2Str(actual) + ", expected " + Array2Str(expected));
}
}
/// <summary>
/// Asserts that 2 float arrays are the same
/// </summary>
/// <param name="expected">The expected array</param>
/// <param name="actual">The actual array</param>
public static void AssertArraysAreEqual(float[] expected, float[] actual)
{
Assert.AreEqual(expected.Length, actual.Length, "Lengths are not the same");
for (int i = 0; i < actual.Length; i++)
{
Assert.AreEqual(expected[i], actual[i], "Got " + Array2Str(actual) + ", expected " + Array2Str(expected));
}
}
/// <summary>
/// Asserts that the sub-arrays of the total array are equal to specific subarrays at specific subarray indicies and equal to a default everywhere else.

6
com.unity.ml-agents.extensions/Tests/Runtime/Sensors/ArticulationBodySensorTests.cs


0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
}
[Test]

#endif
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
// Update the settings to only process joint observations
sensorComponent.Settings = new PhysicsSensorSettings

0f, // joint2.force
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
}
}
}

6
com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodySensorTests.cs


// The root body is ignored since it always generates identity values
// and there are no other bodies to generate observations.
var expected = new float[0];
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
SensorTestHelper.CompareObservation(sensor, expected);
}

-1f, 1f, 0f, // Attached vel
0f, -1f, 1f // Leaf vel
};
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
SensorTestHelper.CompareObservation(sensor, expected);
// Update the settings to only process joint observations

0f, 0f, 0f, // joint2.torque
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
Assert.AreEqual(expected.Length, sensorComponent.CreateSensor().GetObservationSpec().Shape[0]);
}
}

1
com.unity.ml-agents/CHANGELOG.md


and `IDimensionPropertiesSensor` interfaces were removed. (#5127)
- `ISensor.GetCompressionType()` was removed, and `GetCompressionSpec()` was added. The `ISparseChannelSensor`
interface was removed. (#5164)
- The abstract method `SensorComponent.GetObservationShape()` was no longer being called, so it has been removed. (#5172)
#### ml-agents / ml-agents-envs / gym-unity (Python)

2
com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs


NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits,
};
}
#endregion
#endregion
}
}

2
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs


{
return BuiltInSensorType.BufferSensor;
}
}

6
com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs


return m_Sensor;
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
return new[] { MaxNumObservables, ObservableSize };
}
/// <summary>
/// Appends an observation to the buffer. If the buffer is full (maximum number
/// of observation is reached) the observation will be ignored. the length of

1
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


{
return BuiltInSensorType.CameraSensor;
}
}
}

15
com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs


}
/// <summary>
/// Computes the observation shape of the sensor.
/// </summary>
/// <returns>The observation shape of the associated <see cref="CameraSensor"/> object.</returns>
public override int[] GetObservationShape()
{
var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
var cameraSensorshape = CameraSensor.GenerateShape(m_Width, m_Height, Grayscale);
if (stacks > 1)
{
cameraSensorshape[cameraSensorshape.Length - 1] *= stacks;
}
return cameraSensorshape;
}
/// <summary>
/// Update fields that are safe to change on the Sensor at runtime.
/// </summary>
internal void UpdateSensor()

14
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs


}
/// <summary>
/// Returns the observation shape for this raycast sensor which depends on the number
/// of tags for detected objects and the number of rays.
/// </summary>
/// <returns></returns>
public override int[] GetObservationShape()
{
var numRays = 2 * RaysPerDirection + 1;
var numTags = m_DetectableTags?.Count ?? 0;
var obsSize = (numTags + 2) * numRays;
var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
return new[] { obsSize * stacks };
}
/// <summary>
/// Get the RayPerceptionInput that is used by the <see cref="RayPerceptionSensor"/>.
/// </summary>
/// <returns></returns>

1
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


{
return BuiltInSensorType.ReflectionSensor;
}
}
}

1
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs


return BuiltInSensorType.RenderTextureSensor;
}
/// <summary>
/// Converts a RenderTexture to a 2D texture.
/// </summary>

16
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensorComponent.cs


return m_Sensor;
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
var width = RenderTexture != null ? RenderTexture.width : 0;
var height = RenderTexture != null ? RenderTexture.height : 0;
var observationShape = new[] { height, width, Grayscale ? 1 : 3 };
var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
if (stacks > 1)
{
observationShape[2] *= stacks;
}
return observationShape;
}
/// <summary>
/// Update fields that are safe to change on the Sensor at runtime.
/// </summary>

7
com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs


/// </summary>
/// <returns>Created ISensor object.</returns>
public abstract ISensor CreateSensor();
/// <summary>
/// Returns the shape of the sensor observations that will be created.
/// </summary>
/// <returns>Shape of the sensor observation.</returns>
public abstract int[] GetObservationShape();
}
}

1
com.unity.ml-agents/Tests/Editor/Analytics/TrainingAnalyticsTest.cs


#else
Assert.IsFalse(TrainingAnalytics.EnableAnalytics());
#endif
}
}
}

7
com.unity.ml-agents/Tests/Editor/Inference/ParameterLoaderTest.cs


{
return Sensor;
}
}
public override int[] GetObservationShape()
{
var shape = Sensor.GetObservationSpec().Shape;
return new int[] { shape[0], shape[1], shape[2] };
}
}
public class Test3DSensor : ISensor, IBuiltInSensor
{
int m_Width;

1
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


namespace Unity.MLAgents.Tests
{
[TestFixture]
public class EditModeTestGeneration
{

11
com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs


var wrappedSensor = wrappedComponent.CreateSensor();
return new StackingSensor(wrappedSensor, numStacks);
}
public override int[] GetObservationShape()
{
int[] shape = (int[])wrappedComponent.GetObservationShape().Clone();
for (var i = 0; i < shape.Length; i++)
{
shape[i] *= numStacks;
}
return shape;
}
}
public class RuntimeApiTest

4
com.unity.ml-agents/Tests/Runtime/Sensor/BufferSensorTest.cs


bufferComponent.SensorName = "TestName";
var sensor = bufferComponent.CreateSensor();
var shape = bufferComponent.GetObservationShape();
var shape = sensor.GetObservationSpec().Shape;
Assert.AreEqual(shape[0], 20);
Assert.AreEqual(shape[1], 4);

var obsWriter = new ObservationWriter();
var obs = sensor.GetObservationProto(obsWriter);
Assert.AreEqual(shape, obs.Shape);
Assert.AreEqual(shape, InplaceArray<int>.FromList(obs.Shape));
Assert.AreEqual(obs.DimensionProperties.Count, 2);
Assert.AreEqual(sensor.GetName(), "TestName");

7
com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorComponentTest.cs


cameraComponent.Grayscale = grayscale;
cameraComponent.CompressionType = compression;
var expectedShape = new[] { height, width, grayscale ? 1 : 3 };
Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape());
var expectedShapeInplace = new InplaceArray<int>(height, width, grayscale ? 1 : 3);
Assert.AreEqual(expectedShapeInplace, sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(height, width, grayscale ? 1 : 3);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(CameraSensor), sensor.GetType());
}
}

5
com.unity.ml-agents/Tests/Runtime/Sensor/RenderTextureSensorComponentTests.cs


renderTexComponent.Grayscale = grayscale;
renderTexComponent.CompressionType = compression;
var expectedShape = new[] { height, width, grayscale ? 1 : 3 };
Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape());
var expectedShape = new InplaceArray<int>(height, width, grayscale ? 1 : 3);
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
Assert.AreEqual(typeof(RenderTextureSensor), sensor.GetType());
}
}

2
com.unity.ml-agents/Tests/Runtime/Sensor/SensorShapeValidatorTests.cs


public class SensorShapeValidatorTests
{
[Test]
public void TestShapesAgree()
{

LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
validator.ValidateSensors(sensorList1);
}
[Test]
public void TestDimensionMismatch()
{

1
com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs


{
return "Dummy";
}
}
[Test]

1
com.unity.ml-agents/Tests/Runtime/Utils/TestClasses.cs


public class TestClasses
{
}
}

4
docs/Migrating.md


actionMask.SetActionEnabled(branch, 3, false);
}
```
### IActuator changes
### ISensor and SensorComponent changes
- The `ISensor.GetObservationShape()` method and `ITypedSensor`
and `IDimensionPropertiesSensor` interfaces were removed, and `GetObservationSpec()` was added. You can use
`ObservationSpec.Vector()` or `ObservationSpec.Visual()` to generate `ObservationSpec`s that are equivalent to

return CompressionSpec.Default();
}
```
- The abstract method `SensorComponent.GetObservationShape()` was removed.
## Migrating to Release 13
### Implementing IHeuristic in your IActuator implementations

正在加载...
取消
保存