您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
231 行
9.7 KiB
231 行
9.7 KiB
using NUnit.Framework;
|
|
using System;
|
|
using System.Linq;
|
|
using UnityEngine;
|
|
using Unity.MLAgents.Sensors;
|
|
|
|
namespace Unity.MLAgents.Tests
|
|
{
|
|
public class StackingSensorTests
|
|
{
|
|
[Test]
|
|
public void TestCtor()
|
|
{
|
|
ISensor wrapped = new VectorSensor(4);
|
|
ISensor sensor = new StackingSensor(wrapped, 4);
|
|
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
|
|
Assert.AreEqual(sensor.GetObservationSpec().Shape, new[] { 16 });
|
|
}
|
|
|
|
[Test]
|
|
public void TestVectorStacking()
|
|
{
|
|
VectorSensor wrapped = new VectorSensor(2);
|
|
ISensor sensor = new StackingSensor(wrapped, 3);
|
|
|
|
wrapped.AddObservation(new[] { 1f, 2f });
|
|
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f });
|
|
|
|
sensor.Update();
|
|
wrapped.AddObservation(new[] { 3f, 4f });
|
|
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f });
|
|
|
|
sensor.Update();
|
|
wrapped.AddObservation(new[] { 5f, 6f });
|
|
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f, 5f, 6f });
|
|
|
|
sensor.Update();
|
|
wrapped.AddObservation(new[] { 7f, 8f });
|
|
SensorTestHelper.CompareObservation(sensor, new[] { 3f, 4f, 5f, 6f, 7f, 8f });
|
|
|
|
sensor.Update();
|
|
wrapped.AddObservation(new[] { 9f, 10f });
|
|
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
|
|
|
|
// Check that if we don't call Update(), the same observations are produced
|
|
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
|
|
}
|
|
|
|
[Test]
|
|
public void TestVectorStackingReset()
|
|
{
|
|
VectorSensor wrapped = new VectorSensor(2);
|
|
ISensor sensor = new StackingSensor(wrapped, 3);
|
|
|
|
wrapped.AddObservation(new[] { 1f, 2f });
|
|
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f });
|
|
|
|
sensor.Update();
|
|
wrapped.AddObservation(new[] { 3f, 4f });
|
|
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f });
|
|
|
|
sensor.Reset();
|
|
wrapped.AddObservation(new[] { 5f, 6f });
|
|
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 5f, 6f });
|
|
}
|
|
|
|
class Dummy3DSensor : ISparseChannelSensor
|
|
{
|
|
public SensorCompressionType CompressionType = SensorCompressionType.PNG;
|
|
public int[] Mapping;
|
|
public ObservationSpec ObservationSpec;
|
|
public float[,,] CurrentObservation;
|
|
|
|
internal Dummy3DSensor()
|
|
{
|
|
}
|
|
|
|
public ObservationSpec GetObservationSpec()
|
|
{
|
|
return ObservationSpec;
|
|
}
|
|
|
|
public int Write(ObservationWriter writer)
|
|
{
|
|
for (var h = 0; h < ObservationSpec.Shape[0]; h++)
|
|
{
|
|
for (var w = 0; w < ObservationSpec.Shape[1]; w++)
|
|
{
|
|
for (var c = 0; c < ObservationSpec.Shape[2]; c++)
|
|
{
|
|
writer[h, w, c] = CurrentObservation[h, w, c];
|
|
}
|
|
}
|
|
}
|
|
return ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2];
|
|
}
|
|
|
|
public byte[] GetCompressedObservation()
|
|
{
|
|
var writer = new ObservationWriter();
|
|
var flattenedObservation = new float[ObservationSpec.Shape[0] * ObservationSpec.Shape[1] * ObservationSpec.Shape[2]];
|
|
writer.SetTarget(flattenedObservation, ObservationSpec.Shape, 0);
|
|
Write(writer);
|
|
byte[] bytes = Array.ConvertAll(flattenedObservation, (z) => (byte)z);
|
|
return bytes;
|
|
}
|
|
|
|
public void Update() { }
|
|
|
|
public void Reset() { }
|
|
|
|
public SensorCompressionType GetCompressionType()
|
|
{
|
|
return CompressionType;
|
|
}
|
|
|
|
public string GetName()
|
|
{
|
|
return "Dummy";
|
|
}
|
|
|
|
public int[] GetCompressedChannelMapping()
|
|
{
|
|
return Mapping;
|
|
}
|
|
|
|
}
|
|
|
|
[Test]
|
|
public void TestStackingMapping()
|
|
{
|
|
// Test grayscale stacked mapping with CameraSensor
|
|
var cameraSensor = new CameraSensor(new Camera(), 64, 64,
|
|
true, "grayscaleCamera", SensorCompressionType.PNG);
|
|
var stackedCameraSensor = new StackingSensor(cameraSensor, 2);
|
|
Assert.AreEqual(stackedCameraSensor.GetCompressedChannelMapping(), new[] { 0, 0, 0, 1, 1, 1 });
|
|
|
|
// Test RGB stacked mapping with RenderTextureSensor
|
|
var renderTextureSensor = new RenderTextureSensor(new RenderTexture(24, 16, 0),
|
|
false, "renderTexture", SensorCompressionType.PNG);
|
|
var stackedRenderTextureSensor = new StackingSensor(renderTextureSensor, 2);
|
|
Assert.AreEqual(stackedRenderTextureSensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, 4, 5 });
|
|
|
|
// Test mapping with number of layers not being multiple of 3
|
|
var dummySensor = new Dummy3DSensor();
|
|
dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4);
|
|
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
|
|
var stackedDummySensor = new StackingSensor(dummySensor, 2);
|
|
Assert.AreEqual(stackedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
|
|
|
|
// Test mapping with dummy layers that should be dropped
|
|
var paddedDummySensor = new Dummy3DSensor();
|
|
paddedDummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4);
|
|
paddedDummySensor.Mapping = new[] { 0, 1, 2, 3, -1, -1 };
|
|
var stackedPaddedDummySensor = new StackingSensor(paddedDummySensor, 2);
|
|
Assert.AreEqual(stackedPaddedDummySensor.GetCompressedChannelMapping(), new[] { 0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1 });
|
|
}
|
|
|
|
[Test]
|
|
public void Test3DStacking()
|
|
{
|
|
var wrapped = new Dummy3DSensor();
|
|
wrapped.ObservationSpec = ObservationSpec.FromShape(2, 1, 2);
|
|
var sensor = new StackingSensor(wrapped, 2);
|
|
|
|
// Check the stacking is on the last dimension
|
|
wrapped.CurrentObservation = new[, ,] { { { 1f, 2f } }, { { 3f, 4f } } };
|
|
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 1f, 2f } }, { { 0f, 0f, 3f, 4f } } });
|
|
|
|
sensor.Update();
|
|
wrapped.CurrentObservation = new[, ,] { { { 5f, 6f } }, { { 7f, 8f } } };
|
|
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 1f, 2f, 5f, 6f } }, { { 3f, 4f, 7f, 8f } } });
|
|
|
|
sensor.Update();
|
|
wrapped.CurrentObservation = new[, ,] { { { 9f, 10f } }, { { 11f, 12f } } };
|
|
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } });
|
|
|
|
// Check that if we don't call Update(), the same observations are produced
|
|
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 5f, 6f, 9f, 10f } }, { { 7f, 8f, 11f, 12f } } });
|
|
|
|
// Test reset
|
|
sensor.Reset();
|
|
wrapped.CurrentObservation = new[, ,] { { { 13f, 14f } }, { { 15f, 16f } } };
|
|
SensorTestHelper.CompareObservation(sensor, new[, ,] { { { 0f, 0f, 13f, 14f } }, { { 0f, 0f, 15f, 16f } } });
|
|
}
|
|
|
|
[Test]
|
|
public void TestStackedGetCompressedObservation()
|
|
{
|
|
var wrapped = new Dummy3DSensor();
|
|
wrapped.ObservationSpec = ObservationSpec.FromShape(1, 1, 3);
|
|
var sensor = new StackingSensor(wrapped, 2);
|
|
|
|
wrapped.CurrentObservation = new[, ,] { { { 1f, 2f, 3f } } };
|
|
var expected1 = sensor.CreateEmptyPNG();
|
|
expected1 = expected1.Concat(Array.ConvertAll(new[] { 1f, 2f, 3f }, (z) => (byte)z)).ToArray();
|
|
Assert.AreEqual(sensor.GetCompressedObservation(), expected1);
|
|
|
|
sensor.Update();
|
|
wrapped.CurrentObservation = new[, ,] { { { 4f, 5f, 6f } } };
|
|
var expected2 = Array.ConvertAll(new[] { 1f, 2f, 3f, 4f, 5f, 6f }, (z) => (byte)z);
|
|
Assert.AreEqual(sensor.GetCompressedObservation(), expected2);
|
|
|
|
sensor.Update();
|
|
wrapped.CurrentObservation = new[, ,] { { { 7f, 8f, 9f } } };
|
|
var expected3 = Array.ConvertAll(new[] { 4f, 5f, 6f, 7f, 8f, 9f }, (z) => (byte)z);
|
|
Assert.AreEqual(sensor.GetCompressedObservation(), expected3);
|
|
|
|
// Test reset
|
|
sensor.Reset();
|
|
wrapped.CurrentObservation = new[, ,] { { { 10f, 11f, 12f } } };
|
|
var expected4 = sensor.CreateEmptyPNG();
|
|
expected4 = expected4.Concat(Array.ConvertAll(new[] { 10f, 11f, 12f }, (z) => (byte)z)).ToArray();
|
|
Assert.AreEqual(sensor.GetCompressedObservation(), expected4);
|
|
}
|
|
|
|
[Test]
|
|
public void TestStackingSensorBuiltInSensorType()
|
|
{
|
|
var dummySensor = new Dummy3DSensor();
|
|
dummySensor.ObservationSpec = ObservationSpec.FromShape(2, 2, 4);
|
|
dummySensor.Mapping = new[] { 0, 1, 2, 3 };
|
|
var stackedDummySensor = new StackingSensor(dummySensor, 2);
|
|
Assert.AreEqual(stackedDummySensor.GetBuiltInSensorType(), BuiltInSensorType.Unknown);
|
|
|
|
var vectorSensor = new VectorSensor(4);
|
|
var stackedVectorSensor = new StackingSensor(vectorSensor, 4);
|
|
Assert.AreEqual(stackedVectorSensor.GetBuiltInSensorType(), BuiltInSensorType.VectorSensor);
|
|
}
|
|
}
|
|
}
|