using NUnit.Framework; using Unity.MLAgents.Sensors; namespace Unity.MLAgents.Tests { public class Float2DSensor : ISensor { public int Width { get; } public int Height { get; } string m_Name; private ObservationSpec m_ObservationSpec; public float[,] floatData; public Float2DSensor(int width, int height, string name) { Width = width; Height = height; m_Name = name; m_ObservationSpec = ObservationSpec.Visual(height, width, 1); floatData = new float[Height, Width]; } public Float2DSensor(float[,] floatData, string name) { this.floatData = floatData; Height = floatData.GetLength(0); Width = floatData.GetLength(1); m_Name = name; m_ObservationSpec = ObservationSpec.Visual(Height, Width, 1); } public string GetName() { return m_Name; } public ObservationSpec GetObservationSpec() { return m_ObservationSpec; } public byte[] GetCompressedObservation() { return null; } public int Write(ObservationWriter writer) { using (TimerStack.Instance.Scoped("Float2DSensor.Write")) { for (var h = 0; h < Height; h++) { for (var w = 0; w < Width; w++) { writer[h, w, 0] = floatData[h, w]; } } var numWritten = Height * Width; return numWritten; } } public void Update() { } public void Reset() { } public CompressionSpec GetCompressionSpec() { return CompressionSpec.Default(); } } public class FloatVisualSensorTests { [Test] public void TestFloat2DSensorWrite() { var sensor = new Float2DSensor(3, 4, "floatsensor"); for (var h = 0; h < 4; h++) { for (var w = 0; w < 3; w++) { sensor.floatData[h, w] = 3 * h + w; } } var output = new float[12]; var writer = new ObservationWriter(); writer.SetTarget(output, sensor.GetObservationSpec(), 0); sensor.Write(writer); for (var i = 0; i < 9; i++) { Assert.AreEqual(i, output[i]); } } [Test] public void TestFloat2DSensorExternalData() { var data = new float[4, 3]; var sensor = new Float2DSensor(data, "floatsensor"); Assert.AreEqual(sensor.Height, 4); Assert.AreEqual(sensor.Width, 3); } } }