using System;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Profiling;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// How the GridSensor processes the colliders detected in a cell.
    /// </summary>
    public enum ProcessCollidersMethod
    {
        /// <summary>
        /// Get data from all colliders detected in a cell.
        /// </summary>
        ProcessAllColliders,

        /// <summary>
        /// Get data from the collider closest to the agent.
        /// </summary>
        ProcessClosestColliders
    }

    /// <summary>
    /// Grid-based sensor.
    /// </summary>
    public class GridSensorBase : ISensor, IBuiltInSensor, IDisposable
    {
        string m_Name;
        Vector3 m_CellScale;
        Vector3Int m_GridSize;
        string[] m_DetectableTags;
        SensorCompressionType m_CompressionType;
        ObservationSpec m_ObservationSpec;
        internal IGridPerception m_GridPerception;

        // Buffers
        float[] m_PerceptionBuffer;
        Color[] m_PerceptionColors;
        Texture2D m_PerceptionTexture;
        float[] m_CellDataBuffer;

        // Utility Constants Calculated on Init
        int m_NumCells;
        int m_CellObservationSize;
        Vector3 m_CellCenterOffset;

        /// <summary>
        /// Create a GridSensorBase with the specified configuration.
        /// </summary>
        /// <param name="name">The sensor name</param>
        /// <param name="cellScale">The scale of each cell in the grid</param>
        /// <param name="gridSize">Number of cells on each side of the grid</param>
        /// <param name="detectableTags">Tags to be detected by the sensor</param>
        /// <param name="compression">Compression type</param>
        /// <exception cref="UnityAgentsException">Thrown when <paramref name="gridSize"/>.y is not 1.</exception>
        public GridSensorBase(
            string name,
            Vector3 cellScale,
            Vector3Int gridSize,
            string[] detectableTags,
            SensorCompressionType compression
        )
        {
            m_Name = name;
            m_CellScale = cellScale;
            m_GridSize = gridSize;
            m_DetectableTags = detectableTags;

            // Assigned through the property so PNG is refused for non-normalized data.
            // NOTE: the setter (IsDataNormalized) and GetCellObservationSize below are virtual calls made
            // from the base constructor, i.e. before any derived constructor body has run.
            CompressionType = compression;

            if (m_GridSize.y != 1)
            {
                throw new UnityAgentsException("GridSensor only supports 2D grids.");
            }

            m_NumCells = m_GridSize.x * m_GridSize.z;
            m_CellObservationSize = GetCellObservationSize();
            m_ObservationSpec = ObservationSpec.Visual(m_GridSize.x, m_GridSize.z, m_CellObservationSize);
            m_PerceptionTexture = new Texture2D(m_GridSize.x, m_GridSize.z, TextureFormat.RGB24, false);

            ResetPerceptionBuffer();
        }

        /// <summary>
        /// The compression type used by the sensor.
        /// Requesting <see cref="SensorCompressionType.PNG"/> while <see cref="IsDataNormalized"/> returns
        /// false logs a warning and leaves the current compression type unchanged.
        /// </summary>
        public SensorCompressionType CompressionType
        {
            get { return m_CompressionType; }
            set
            {
                if (!IsDataNormalized() && value == SensorCompressionType.PNG)
                {
                    Debug.LogWarning($"Compression type {value} is only supported with normalized data. " +
                        "The sensor will not compress the data.");
                    return;
                }
                m_CompressionType = value;
            }
        }

        /// <summary>
        /// The flat observation buffer: <c>m_CellObservationSize</c> consecutive values per cell.
        /// </summary>
        internal float[] PerceptionBuffer
        {
            get { return m_PerceptionBuffer; }
        }

        /// <summary>
        /// The tags which the sensor detects.
        /// </summary>
        protected string[] DetectableTags
        {
            get { return m_DetectableTags; }
        }

        /// <inheritdoc/>
        public void Reset() { }

        /// <summary>
        /// Clears the perception buffer before loading in new data.
        /// Allocates the buffers on first use and zeroes them afterwards.
        /// </summary>
        public void ResetPerceptionBuffer()
        {
            if (m_PerceptionBuffer != null)
            {
                Array.Clear(m_PerceptionBuffer, 0, m_PerceptionBuffer.Length);
                Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
            }
            else
            {
                m_PerceptionBuffer = new float[m_CellObservationSize * m_NumCells];
                m_CellDataBuffer = new float[m_CellObservationSize];
                m_PerceptionColors = new Color[m_NumCells];
            }
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_Name;
        }

        /// <inheritdoc/>
        public CompressionSpec GetCompressionSpec()
        {
            return new CompressionSpec(CompressionType);
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.GridSensor;
        }

        /// <inheritdoc/>
        /// <remarks>
        /// The observation channels are packed three at a time into RGB images. The PNG encodings of
        /// those images are concatenated in channel order.
        /// </remarks>
        public byte[] GetCompressedObservation()
        {
            using (TimerStack.Instance.Scoped("GridSensor.GetCompressedObservation"))
            {
                var allBytes = new List<byte>();
                // Ceiling division: one RGB image per group of up to 3 channels.
                var numImages = (m_CellObservationSize + 2) / 3;
                for (int i = 0; i < numImages; i++)
                {
                    var channelIndex = 3 * i;
                    GridValuesToTexture(channelIndex, Math.Min(3, m_CellObservationSize - channelIndex));
                    allBytes.AddRange(m_PerceptionTexture.EncodeToPNG());
                }

                return allBytes.ToArray();
            }
        }

        /// <summary>
        /// Convert observation values to texture for PNG compression.
        /// Copies up to three consecutive observation channels of every cell into the RGB channels of the
        /// perception texture.
        /// </summary>
        /// <param name="channelIndex">Index of the first observation channel to copy.</param>
        /// <param name="numChannelsToAdd">Number of channels to copy (at most 3).</param>
        void GridValuesToTexture(int channelIndex, int numChannelsToAdd)
        {
            for (int i = 0; i < m_NumCells; i++)
            {
                var bufferOffset = i * m_CellObservationSize + channelIndex;
                for (int j = 0; j < 3; j++)
                {
                    // m_PerceptionColors is reused for every image, so zero the colour channels this image
                    // does not use instead of leaking the previous image's values into the PNG.
                    m_PerceptionColors[i][j] = j < numChannelsToAdd ? m_PerceptionBuffer[bufferOffset + j] : 0f;
                }
            }
            m_PerceptionTexture.SetPixels(m_PerceptionColors);
        }

        /// <summary>
        /// Get the observation values of the detected game object.
        /// Default is to record the detected tag index.
        ///
        /// This method can be overridden to encode the observation differently or get custom data from the object.
        /// When overriding this method, <see cref="GetCellObservationSize"/> and <see cref="IsDataNormalized"/>
        /// might also need to change accordingly.
        /// </summary>
        /// <param name="detectedObject">The game object that was detected within a certain cell</param>
        /// <param name="tagIndex">The index of the detectedObject's tag in the DetectableObjects list</param>
        /// <param name="dataBuffer">The buffer to write the observation values.
        ///         The buffer size is configured by <see cref="GetCellObservationSize"/>.
        /// </param>
        /// <example>
        ///   Here is an example of overriding GetObjectData to get the velocity of a potential Rigidbody:
        ///   <code>
        ///     protected override void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
        ///     {
        ///         if (tagIndex == Array.IndexOf(DetectableTags, "RigidBodyObject"))
        ///         {
        ///             Rigidbody rigidbody = detectedObject.GetComponent&lt;Rigidbody&gt;();
        ///             dataBuffer[0] = rigidbody.velocity.x;
        ///             dataBuffer[1] = rigidbody.velocity.y;
        ///             dataBuffer[2] = rigidbody.velocity.z;
        ///         }
        ///     }
        ///   </code>
        /// </example>
        protected virtual void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
        {
            // +1 so that 0 can mean "nothing detected" in a cleared cell.
            dataBuffer[0] = tagIndex + 1;
        }

        /// <summary>
        /// Get the observation size for each cell. This will be the size of dataBuffer for <see cref="GetObjectData"/>.
        /// If overriding <see cref="GetObjectData"/>, override this method as well to the custom observation size.
        /// </summary>
        /// <returns>The observation size of each cell.</returns>
        protected virtual int GetCellObservationSize()
        {
            return 1;
        }

        /// <summary>
        /// Whether the data is normalized within [0, 1]. The sensor can only use PNG compression if the data is normalized.
        /// If overriding <see cref="GetObjectData"/>, override this method as well according to the custom observation values.
        /// </summary>
        /// <returns>Bool value indicating whether data is normalized.</returns>
        protected virtual bool IsDataNormalized()
        {
            return false;
        }

        /// <summary>
        /// How to process the colliders detected in a cell. Defaults to
        /// <see cref="ProcessCollidersMethod.ProcessClosestColliders"/>, i.e. only the collider closest to the agent is used.
        /// If overriding <see cref="GetObjectData"/>, consider overriding this method when needed.
        /// </summary>
        /// <returns>The <see cref="ProcessCollidersMethod"/> to use for each cell.</returns>
        protected internal virtual ProcessCollidersMethod GetProcessCollidersMethod()
        {
            return ProcessCollidersMethod.ProcessClosestColliders;
        }

        /// <summary>
        /// If using PNG compression, check if the values are normalized.
        /// </summary>
        /// <param name="dataValues">The values produced for one detected object.</param>
        /// <param name="detectedObject">The object the values came from (used in the error message).</param>
        /// <exception cref="UnityAgentsException">Thrown when PNG compression is used and a value is outside [0, 1] or NaN.</exception>
        void ValidateValues(float[] dataValues, GameObject detectedObject)
        {
            if (m_CompressionType != SensorCompressionType.PNG)
            {
                return;
            }

            for (int j = 0; j < dataValues.Length; j++)
            {
                // Negated range test so that NaN (which fails every comparison) is rejected as well.
                if (!(dataValues[j] >= 0 && dataValues[j] <= 1))
                {
                    throw new UnityAgentsException($"When using compression type {m_CompressionType} the data value has to be normalized between 0-1. " +
                        $"Received value[{dataValues[j]}] for {detectedObject.name}");
                }
            }
        }

        /// <summary>
        /// Collect data from the detected object if a detectable tag is matched.
        /// Only the first matching tag (in <see cref="DetectableTags"/> order) is used.
        /// </summary>
        /// <param name="detectedObject">The detected object; a null reference is ignored.</param>
        /// <param name="cellIndex">Index of the cell whose slice of the perception buffer is written.</param>
        internal void ProcessDetectedObject(GameObject detectedObject, int cellIndex)
        {
            // ReferenceEquals bypasses Unity's overloaded == operator; this check is loop-invariant.
            if (ReferenceEquals(detectedObject, null))
            {
                return;
            }

            Profiler.BeginSample("GridSensor.ProcessDetectedObject");
            try
            {
                for (var i = 0; i < m_DetectableTags.Length; i++)
                {
                    if (!detectedObject.CompareTag(m_DetectableTags[i]))
                    {
                        continue;
                    }

                    if (GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders)
                    {
                        // Start from the cell's current values so GetObjectData can accumulate across colliders.
                        Array.Copy(m_PerceptionBuffer, cellIndex * m_CellObservationSize, m_CellDataBuffer, 0, m_CellObservationSize);
                    }
                    else
                    {
                        Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
                    }

                    GetObjectData(detectedObject, i, m_CellDataBuffer);
                    ValidateValues(m_CellDataBuffer, detectedObject);
                    Array.Copy(m_CellDataBuffer, 0, m_PerceptionBuffer, cellIndex * m_CellObservationSize, m_CellObservationSize);
                    break;
                }
            }
            finally
            {
                // Keep the profiler sample balanced even when ValidateValues/GetObjectData throws.
                Profiler.EndSample();
            }
        }

        /// <inheritdoc/>
        public void Update()
        {
            ResetPerceptionBuffer();
            using (TimerStack.Instance.Scoped("GridSensor.Update"))
            {
                if (m_GridPerception != null)
                {
                    m_GridPerception.Perceive();
                }
            }
        }

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <inheritdoc/>
        public int Write(ObservationWriter writer)
        {
            using (TimerStack.Instance.Scoped("GridSensor.Write"))
            {
                // The buffer is read sequentially while rows are emitted from h = z-1 down to 0,
                // so the buffer's first row lands in the last observation row.
                // NOTE(review): presumably this matches the image row convention; confirm with IGridPerception.
                int index = 0;
                for (var h = m_GridSize.z - 1; h >= 0; h--)
                {
                    for (var w = 0; w < m_GridSize.x; w++)
                    {
                        for (var d = 0; d < m_CellObservationSize; d++)
                        {
                            writer[h, w, d] = m_PerceptionBuffer[index];
                            index++;
                        }
                    }
                }
                return index;
            }
        }

        /// <summary>
        /// Clean up the internal objects.
        /// </summary>
        public void Dispose()
        {
            if (!ReferenceEquals(null, m_PerceptionTexture))
            {
                Utilities.DestroyTexture(m_PerceptionTexture);
                m_PerceptionTexture = null;
            }
        }
    }
}