using System.Collections.Generic;
using System.IO;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Match3;
using UnityEngine;
using Unity.MLAgents.Extensions.Tests.Sensors;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Tests.Match3
{
    /// <summary>
    /// Tests for the observations produced by <see cref="Match3SensorComponent"/> in its vector,
    /// uncompressed-visual and PNG-compressed-visual modes, with and without special cell types.
    /// </summary>
    public class Match3SensorTests
    {
        // Whether the expected PNG data should be written to a file.
        // Only set this to true if the compressed observation format changes.
        private bool WritePNGDataToFile = false;

        private const string k_CellObservationPng = "match3obs";
        private const string k_SpecialObservationPng = "match3obs_special";

        // Directory (relative to the project root) holding the reference PNG files.
        private const string k_PngDirectory = "Packages/com.unity.ml-agents.extensions/Tests/Editor/Match3/";

        // Board layouts shared by every test. The single non-zero cell type ('1' in the last group of
        // k_BoardString) appears as the second cell of the first observation row in the expectations below.
        private const string k_BoardString = @"000 000 010";
        private const string k_SpecialString = @"010 200 000";

        // The 8-byte signature that begins every PNG file.
        private static readonly byte[] s_PngSignature = { 137, 80, 78, 71, 13, 10, 26, 10 };

        // GameObjects created by the running test; destroyed in DestroyCreatedObjects so they do not
        // accumulate in the scene across tests.
        private readonly List<GameObject> m_CreatedObjects = new List<GameObject>();

        [TearDown]
        public void DestroyCreatedObjects()
        {
            foreach (var obj in m_CreatedObjects)
            {
                if (obj != null)
                {
                    UnityEngine.Object.DestroyImmediate(obj);
                }
            }
            m_CreatedObjects.Clear();
        }

        [Test]
        public void TestVectorObservations()
        {
            var sensorComponent = CreateSensorComponent(k_BoardString, null, Match3ObservationType.Vector);
            var sensor = sensorComponent.CreateSensors()[0];

            var expectedShape = new InplaceArray<int>(3 * 3 * 2);
            Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

            var expectedObs = new float[]
            {
                1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */
                1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
                1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
            };
            SensorTestHelper.CompareObservation(sensor, expectedObs);
        }

        [Test]
        public void TestVectorObservationsSpecial()
        {
            var sensorComponent = CreateSensorComponent(k_BoardString, k_SpecialString, Match3ObservationType.Vector);
            var sensors = sensorComponent.CreateSensors();
            var cellSensor = sensors[0];
            var specialSensor = sensors[1];

            {
                var expectedShape = new InplaceArray<int>(3 * 3 * 2);
                Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape);

                var expectedObs = new float[]
                {
                    1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */
                    1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
                    1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
                };
                SensorTestHelper.CompareObservation(cellSensor, expectedObs);
            }
            {
                var expectedShape = new InplaceArray<int>(3 * 3 * 3);
                Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape);

                var expectedObs = new float[]
                {
                    1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
                    0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
                    1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */
                };
                SensorTestHelper.CompareObservation(specialSensor, expectedObs);
            }
        }

        [Test]
        public void TestVisualObservations()
        {
            var sensorComponent = CreateSensorComponent(k_BoardString, null, Match3ObservationType.UncompressedVisual);
            var sensor = sensorComponent.CreateSensors()[0];

            var expectedShape = new InplaceArray<int>(3, 3, 2);
            Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
            Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);

            var expectedObs = new float[]
            {
                1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */
                1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
                1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
            };
            SensorTestHelper.CompareObservation(sensor, expectedObs);

            var expectedObs3D = new float[,,]
            {
                {{1, 0}, {0, 1}, {1, 0}},
                {{1, 0}, {1, 0}, {1, 0}},
                {{1, 0}, {1, 0}, {1, 0}},
            };
            SensorTestHelper.CompareObservation(sensor, expectedObs3D);
        }

        [Test]
        public void TestVisualObservationsSpecial()
        {
            var sensorComponent = CreateSensorComponent(k_BoardString, k_SpecialString, Match3ObservationType.UncompressedVisual);
            var sensors = sensorComponent.CreateSensors();
            var cellSensor = sensors[0];
            var specialSensor = sensors[1];

            {
                var expectedShape = new InplaceArray<int>(3, 3, 2);
                Assert.AreEqual(expectedShape, cellSensor.GetObservationSpec().Shape);
                Assert.AreEqual(SensorCompressionType.None, cellSensor.GetCompressionSpec().SensorCompressionType);

                var expectedObs = new float[]
                {
                    1, 0, /* (0) */ 0, 1, /* (1) */ 1, 0, /* (0) */
                    1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
                    1, 0, /* (0) */ 1, 0, /* (0) */ 1, 0, /* (0) */
                };
                SensorTestHelper.CompareObservation(cellSensor, expectedObs);

                var expectedObs3D = new float[,,]
                {
                    {{1, 0}, {0, 1}, {1, 0}},
                    {{1, 0}, {1, 0}, {1, 0}},
                    {{1, 0}, {1, 0}, {1, 0}},
                };
                SensorTestHelper.CompareObservation(cellSensor, expectedObs3D);
            }
            {
                var expectedShape = new InplaceArray<int>(3, 3, 3);
                Assert.AreEqual(expectedShape, specialSensor.GetObservationSpec().Shape);
                Assert.AreEqual(SensorCompressionType.None, specialSensor.GetCompressionSpec().SensorCompressionType);

                var expectedObs = new float[]
                {
                    1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
                    0, 0, 1, /* (2) */ 1, 0, 0, /* (0) */ 1, 0, 0, /* (0) */
                    1, 0, 0, /* (0) */ 0, 1, 0, /* (1) */ 1, 0, 0, /* (0) */
                };
                SensorTestHelper.CompareObservation(specialSensor, expectedObs);

                var expectedObs3D = new float[,,]
                {
                    {{1, 0, 0}, {1, 0, 0}, {1, 0, 0}},
                    {{0, 0, 1}, {1, 0, 0}, {1, 0, 0}},
                    {{1, 0, 0}, {0, 1, 0}, {1, 0, 0}},
                };
                SensorTestHelper.CompareObservation(specialSensor, expectedObs3D);
            }
        }

        [Test]
        public void TestCompressedVisualObservations()
        {
            var sensorComponent = CreateSensorComponent(k_BoardString, null, Match3ObservationType.CompressedVisual);
            var sensor = sensorComponent.CreateSensors()[0];

            CheckCompressedObservation(sensor, 2, k_CellObservationPng);
        }

        [Test]
        public void TestCompressedVisualObservationsSpecial()
        {
            var sensorComponent = CreateSensorComponent(k_BoardString, k_SpecialString, Match3ObservationType.CompressedVisual);
            var sensors = sensorComponent.CreateSensors();

            var paths = new[] { k_CellObservationPng, k_SpecialObservationPng };
            var expectedChannels = new[] { 2, 3 };
            for (var i = 0; i < paths.Length; i++)
            {
                CheckCompressedObservation(sensors[i], expectedChannels[i], paths[i]);
            }
        }

        /// <summary>
        /// Creates a GameObject with a board and a <see cref="Match3SensorComponent"/> configured for the
        /// given observation type. The GameObject is destroyed after the current test.
        /// </summary>
        /// <param name="boardString">Layout passed to the board's SetBoard.</param>
        /// <param name="specialString">Layout passed to the board's SetSpecial, or null to skip that call.</param>
        /// <param name="observationType">Observation mode assigned to the sensor component.</param>
        /// <returns>The configured sensor component.</returns>
        private Match3SensorComponent CreateSensorComponent(
            string boardString, string specialString, Match3ObservationType observationType)
        {
            var gameObj = new GameObject("board");
            m_CreatedObjects.Add(gameObj);

            var board = gameObj.AddComponent<StringBoard>();
            board.SetBoard(boardString);
            if (specialString != null)
            {
                board.SetSpecial(specialString);
            }

            var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
            sensorComponent.ObservationType = observationType;
            return sensorComponent;
        }

        /// <summary>
        /// Asserts that a PNG-compressed 3x3 sensor has the expected shape and that its compressed
        /// observation matches the reference PNG stored under <paramref name="pngPrefix"/>.
        /// </summary>
        /// <param name="sensor">The sensor under test.</param>
        /// <param name="expectedChannels">Expected number of channels in the observation shape.</param>
        /// <param name="pngPrefix">File-name prefix of the reference PNG.</param>
        private void CheckCompressedObservation(ISensor sensor, int expectedChannels, string pngPrefix)
        {
            var expectedShape = new InplaceArray<int>(3, 3, expectedChannels);
            Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);
            Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);

            var pngData = sensor.GetCompressedObservation();
            if (WritePNGDataToFile)
            {
                // Enable this if the format of the observation changes
                SavePNGs(pngData, pngPrefix);
            }

            var expectedPng = LoadPNGs(pngPrefix, 1);
            Assert.AreEqual(expectedPng, pngData);
        }

        /// <summary>
        /// Helper method for un-concatenating PNG observations.
        /// </summary>
        /// <param name="concatenated">One or more PNG files joined back to back.</param>
        /// <returns>
        /// The individual PNG files in order; a new file starts wherever the PNG signature occurs after
        /// position 0. Empty when <paramref name="concatenated"/> is empty.
        /// </returns>
        List<byte[]> SplitPNGs(byte[] concatenated)
        {
            var pngsOut = new List<byte[]>();
            var current = new List<byte>();
            for (var i = 0; i < concatenated.Length; i++)
            {
                current.Add(concatenated[i]);

                // If a PNG signature begins at the next position, the current file ends here.
                if (StartsWithPngSignature(concatenated, i + 1))
                {
                    pngsOut.Add(current.ToArray());
                    current.Clear();
                }
            }

            if (current.Count > 0)
            {
                pngsOut.Add(current.ToArray());
            }
            return pngsOut;
        }

        /// <summary>
        /// Returns true if the complete PNG signature appears in <paramref name="data"/> starting at
        /// <paramref name="offset"/>.
        /// </summary>
        static bool StartsWithPngSignature(byte[] data, int offset)
        {
            // The whole signature must fit; it is allowed to end exactly at the last byte of data.
            if (offset + s_PngSignature.Length > data.Length)
            {
                return false;
            }

            for (var j = 0; j < s_PngSignature.Length; j++)
            {
                if (data[offset + j] != s_PngSignature[j])
                {
                    return false;
                }
            }
            return true;
        }

        /// <summary>
        /// Splits concatenated PNG data and writes each file to "{pathPrefix}{index}.png" in the
        /// reference directory, overwriting existing files.
        /// </summary>
        void SavePNGs(byte[] concatenatedPngData, string pathPrefix)
        {
            var splitPngs = SplitPNGs(concatenatedPngData);
            for (var i = 0; i < splitPngs.Count; i++)
            {
                File.WriteAllBytes(GetPngPath(pathPrefix, i), splitPngs[i]);
            }
        }

        /// <summary>
        /// Reads "{pathPrefix}0.png" .. "{pathPrefix}{numExpected - 1}.png" from the reference directory
        /// and returns their bytes concatenated in order.
        /// </summary>
        byte[] LoadPNGs(string pathPrefix, int numExpected)
        {
            var bytesOut = new List<byte>();
            for (var i = 0; i < numExpected; i++)
            {
                bytesOut.AddRange(File.ReadAllBytes(GetPngPath(pathPrefix, i)));
            }
            return bytesOut.ToArray();
        }

        /// <summary>
        /// Builds the path of the reference PNG with the given prefix and index.
        /// </summary>
        static string GetPngPath(string pathPrefix, int index)
        {
            return $"{k_PngDirectory}{pathPrefix}{index}.png";
        }
    }
}