using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using UnityEngine;
using Debug = UnityEngine.Debug;

namespace Unity.MLAgents.Extensions.Match3
{
    /// <summary>
    /// Delegate that provides integer values at a given (x,y) coordinate.
    /// </summary>
    /// <remarks>
    /// <see cref="Match3Sensor"/> invokes this with (row, column) board coordinates.
    /// </remarks>
    /// <param name="x">First coordinate; the sensor passes the board row.</param>
    /// <param name="y">Second coordinate; the sensor passes the board column.</param>
    /// <returns>The value at the coordinate; the sensor uses it as a one-hot index.</returns>
    public delegate int GridValueProvider(int x, int y);

    /// <summary>
    /// Type of observations to generate.
    /// </summary>
    public enum Match3ObservationType
    {
        /// <summary>
        /// Generate a one-hot encoding of the cell type for each cell on the board. If there are special types,
        /// these will also be one-hot encoded.
        /// </summary>
        Vector,

        /// <summary>
        /// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as
        /// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded.
        /// </summary>
        UncompressedVisual,

        /// <summary>
        /// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as
        /// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded.
        /// During training, these will be sent as a concatenated series of PNG images, with 3 channels per image.
        /// </summary>
        CompressedVisual
    }

    /// <summary>
    /// Sensor for Match3 games. Can generate either vector, compressed visual,
    /// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
    /// </summary>
    public class Match3Sensor : ISensor, IBuiltInSensor
    {
        readonly Match3ObservationType m_ObservationType;
        readonly ObservationSpec m_ObservationSpec;
        readonly string m_Name;

        readonly AbstractBoard m_Board;
        readonly BoardSize m_MaxBoardSize;
        readonly GridValueProvider m_GridValues;
        readonly int m_OneHotSize;

        /// <summary>
        /// Create a sensor for the GridValueProvider with the specified observation type.
        /// </summary>
        /// <remarks>
        /// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling
        /// the constructor directly.
        /// </remarks>
        /// <param name="board">The abstract board.</param>
        /// <param name="gvp">The GridValueProvider, should be either board.GetCellType or board.GetSpecialType.</param>
        /// <param name="oneHotSize">The number of possible values that the GridValueProvider can return.</param>
        /// <param name="obsType">Whether to produce vector or visual observations</param>
        /// <param name="name">Name of the sensor.</param>
        public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name)
        {
            var maxBoardSize = board.GetMaxBoardSize();
            m_Name = name;
            m_MaxBoardSize = maxBoardSize;
            m_GridValues = gvp;
            m_OneHotSize = oneHotSize;
            m_Board = board;

            m_ObservationType = obsType;
            // The spec is always sized for the *max* board; smaller current boards are zero-padded in Write().
            m_ObservationSpec = obsType == Match3ObservationType.Vector
                ? ObservationSpec.Vector(maxBoardSize.Rows * maxBoardSize.Columns * oneHotSize)
                : ObservationSpec.Visual(maxBoardSize.Rows, maxBoardSize.Columns, oneHotSize);
        }

        /// <summary>
        /// Create a sensor that encodes the board cells as observations.
        /// </summary>
        /// <param name="board">The abstract board.</param>
        /// <param name="obsType">Whether to produce vector or visual observations</param>
        /// <param name="name">Name of the sensor.</param>
        /// <returns>A sensor that one-hot encodes board.GetCellType with NumCellTypes values.</returns>
        public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
        {
            var maxBoardSize = board.GetMaxBoardSize();
            return new Match3Sensor(board, board.GetCellType, maxBoardSize.NumCellTypes, obsType, name);
        }

        /// <summary>
        /// Create a sensor that encodes the cell special types as observations. Returns null if the board's
        /// NumSpecialTypes is 0 (indicating the sensor isn't needed).
        /// </summary>
        /// <param name="board">The abstract board.</param>
        /// <param name="obsType">Whether to produce vector or visual observations</param>
        /// <param name="name">Name of the sensor.</param>
        /// <returns>The sensor, or null when the board has no special types.</returns>
        public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
        {
            var maxBoardSize = board.GetMaxBoardSize();
            if (maxBoardSize.NumSpecialTypes == 0)
            {
                return null;
            }
            // NOTE(review): the +1 presumably reserves value 0 for "no special type" - confirm against AbstractBoard.
            var specialSize = maxBoardSize.NumSpecialTypes + 1;
            return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name);
        }

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <inheritdoc/>
        /// <returns>
        /// The number of floats written: max Rows * max Columns * the one-hot size.
        /// </returns>
        public int Write(ObservationWriter writer)
        {
            m_Board.CheckBoardSizes(m_MaxBoardSize);
            var currentBoardSize = m_Board.GetCurrentBoardSize();

            int offset = 0;
            var isVisual = m_ObservationType != Match3ObservationType.Vector;

            // This is equivalent to
            // for (var r = 0; r < m_MaxBoardSize.Rows; r++)
            //     for (var c = 0; c < m_MaxBoardSize.Columns; c++)
            //          if (r < currentBoardSize.Rows && c < currentBoardSize.Columns)
            //              WriteOneHot
            //          else
            //              WriteZero
            // but rearranged to avoid the branching.

            for (var r = 0; r < currentBoardSize.Rows; r++)
            {
                for (var c = 0; c < currentBoardSize.Columns; c++)
                {
                    var val = m_GridValues(r, c);
                    writer.WriteOneHot(offset, r, c, val, m_OneHotSize, isVisual);
                    offset += m_OneHotSize;
                }

                // Zero-pad the columns beyond the current board width.
                for (var c = currentBoardSize.Columns; c < m_MaxBoardSize.Columns; c++)
                {
                    writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                    offset += m_OneHotSize;
                }
            }

            // Zero-pad the rows beyond the current board height. The bound must be the max *row*
            // count so that non-square boards fill exactly Rows x Columns cells, matching the spec.
            for (var r = currentBoardSize.Rows; r < m_MaxBoardSize.Rows; r++)
            {
                for (var c = 0; c < m_MaxBoardSize.Columns; c++)
                {
                    writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                    offset += m_OneHotSize;
                }
            }

            return offset;
        }

        /// <inheritdoc/>
        public byte[] GetCompressedObservation()
        {
            m_Board.CheckBoardSizes(m_MaxBoardSize);
            var height = m_MaxBoardSize.Rows;
            var width = m_MaxBoardSize.Columns;
            var tempTexture = new Texture2D(width, height, TextureFormat.RGB24, false);

            // try/finally so the temporary texture is released even if encoding throws.
            try
            {
                var converter = new OneHotToTextureUtil(height, width);
                var bytesOut = new List<byte>();
                var currentBoardSize = m_Board.GetCurrentBoardSize();

                // Encode the cell types or special types as batches of PNGs.
                // This is potentially wasteful, e.g. if there are 4 cell types and 1 special type, we could
                // fit them in 2 images, but we'll use 3 total (2 PNGs for the 4 cell type channels, and 1 for
                // the special types).
                // Each PNG carries 3 channels, so round the one-hot size up to a multiple of 3.
                var numCellImages = (m_OneHotSize + 2) / 3;
                for (var i = 0; i < numCellImages; i++)
                {
                    converter.EncodeToTexture(
                        m_GridValues,
                        tempTexture,
                        3 * i,
                        currentBoardSize.Rows,
                        currentBoardSize.Columns
                    );
                    bytesOut.AddRange(tempTexture.EncodeToPNG());
                }

                return bytesOut.ToArray();
            }
            finally
            {
                DestroyTexture(tempTexture);
            }
        }

        /// <inheritdoc/>
        public void Update()
        {
        }

        /// <inheritdoc/>
        public void Reset()
        {
        }

        /// <summary>
        /// PNG compression for <see cref="Match3ObservationType.CompressedVisual"/>, otherwise none.
        /// </summary>
        internal SensorCompressionType GetCompressionType()
        {
            return m_ObservationType == Match3ObservationType.CompressedVisual ?
                SensorCompressionType.PNG :
                SensorCompressionType.None;
        }

        /// <inheritdoc/>
        public CompressionSpec GetCompressionSpec()
        {
            return new CompressionSpec(GetCompressionType());
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_Name;
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.Match3Sensor;
        }

        /// <summary>
        /// Release a temporary texture, using DestroyImmediate in the editor and Destroy otherwise.
        /// </summary>
        static void DestroyTexture(Texture2D texture)
        {
            if (Application.isEditor)
            {
                // Edit Mode tests complain if we use Destroy()
                Object.DestroyImmediate(texture);
            }
            else
            {
                Object.Destroy(texture);
            }
        }
    }

    /// <summary>
    /// Utility class for converting a 2D array of ints representing a one-hot encoding into
    /// a texture, suitable for conversion to PNGs for observations.
    /// Works by encoding 3 values at a time as pixels in the texture, thus it should be
    /// called (maxValue + 2) / 3 times, increasing the channelOffset by 3 each time.
    /// </summary>
    internal class OneHotToTextureUtil
    {
        // Reused pixel buffer, one entry per cell of the max board.
        readonly Color[] m_Colors;
        readonly int m_MaxHeight;
        readonly int m_MaxWidth;

        // Values channelOffset, channelOffset+1, channelOffset+2 map to red, green, blue respectively.
        static readonly Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };

        public OneHotToTextureUtil(int maxHeight, int maxWidth)
        {
            m_Colors = new Color[maxHeight * maxWidth];
            m_MaxHeight = maxHeight;
            m_MaxWidth = maxWidth;
        }

        /// <summary>
        /// Fill <paramref name="texture"/> with one color per cell for the values in
        /// [channelOffset, channelOffset + 3); all other values and out-of-board cells are black.
        /// </summary>
        /// <param name="gridValueProvider">Source of per-cell values, called as (row, column).</param>
        /// <param name="texture">Destination texture; expected to be maxWidth x maxHeight.</param>
        /// <param name="channelOffset">First value encoded by this texture.</param>
        /// <param name="currentHeight">Current number of board rows.</param>
        /// <param name="currentWidth">Current number of board columns.</param>
        public void EncodeToTexture(
            GridValueProvider gridValueProvider,
            Texture2D texture,
            int channelOffset,
            int currentHeight,
            int currentWidth
        )
        {
            var i = 0;
            // There's an implicit flip converting to PNG from texture, so make sure we
            // counteract that when forming the texture by iterating through h in reverse.
            for (var h = m_MaxHeight - 1; h >= 0; h--)
            {
                for (var w = 0; w < m_MaxWidth; w++)
                {
                    var colorVal = Color.black;
                    if (h < currentHeight && w < currentWidth)
                    {
                        int oneHotValue = gridValueProvider(h, w);
                        if (oneHotValue >= channelOffset && oneHotValue < channelOffset + 3)
                        {
                            colorVal = s_OneHotColors[oneHotValue - channelOffset];
                        }
                    }
                    m_Colors[i++] = colorVal;
                }
            }
            texture.SetPixels(m_Colors);
        }
    }

    /// <summary>
    /// Utility methods for writing one-hot observations.
    /// </summary>
    internal static class ObservationWriterMatch3Extensions
    {
        /// <summary>
        /// Write a one-hot vector of length <paramref name="oneHotSize"/> with a 1 at <paramref name="value"/>.
        /// Visual observations are written at (row, col, channel); vector observations at consecutive
        /// indices starting at <paramref name="offset"/>.
        /// </summary>
        public static void WriteOneHot(this ObservationWriter writer, int offset, int row, int col, int value, int oneHotSize, bool isVisual)
        {
            if (isVisual)
            {
                for (var i = 0; i < oneHotSize; i++)
                {
                    writer[row, col, i] = (i == value) ? 1.0f : 0.0f;
                }
            }
            else
            {
                for (var i = 0; i < oneHotSize; i++)
                {
                    writer[offset] = (i == value) ? 1.0f : 0.0f;
                    offset++;
                }
            }
        }

        /// <summary>
        /// Write <paramref name="oneHotSize"/> zeros, using the same addressing as <see cref="WriteOneHot"/>.
        /// </summary>
        public static void WriteZero(this ObservationWriter writer, int offset, int row, int col, int oneHotSize, bool isVisual)
        {
            if (isVisual)
            {
                for (var i = 0; i < oneHotSize; i++)
                {
                    writer[row, col, i] = 0.0f;
                }
            }
            else
            {
                for (var i = 0; i < oneHotSize; i++)
                {
                    writer[offset] = 0.0f;
                    offset++;
                }
            }
        }
    }
}