using System;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Profiling;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// Selects how the GridSensor handles the colliders it detects inside a single cell.
/// </summary>
public enum ProcessCollidersMethod
{
    /// <summary>
    /// Collect data from every collider detected in the cell.
    /// </summary>
    ProcessAllColliders = 0,

    /// <summary>
    /// Collect data only from the collider closest to the agent.
    /// </summary>
    ProcessClosestColliders = 1,
}
/// <summary>
/// Grid-based sensor.
/// </summary>
public class GridSensorBase : ISensor, IBuiltInSensor, IDisposable
{
    string m_Name;
    Vector3 m_CellScale;
    Vector3Int m_GridSize;
    string[] m_DetectableTags;
    SensorCompressionType m_CompressionType;
    ObservationSpec m_ObservationSpec;
    internal IGridPerception m_GridPerception;

    // Buffers
    float[] m_PerceptionBuffer;
    Color[] m_PerceptionColors;
    Texture2D m_PerceptionTexture;
    float[] m_CellDataBuffer;

    // Utility Constants Calculated on Init
    int m_NumCells;
    int m_CellObservationSize;
    Vector3 m_CellCenterOffset;

    /// <summary>
    /// Create a GridSensorBase with the specified configuration.
    /// </summary>
    /// <param name="name">The sensor name</param>
    /// <param name="cellScale">The scale of each cell in the grid</param>
    /// <param name="gridSize">Number of cells on each side of the grid</param>
    /// <param name="detectableTags">Tags to be detected by the sensor</param>
    /// <param name="compression">Compression type</param>
    /// <exception cref="UnityAgentsException">Thrown when <paramref name="gridSize"/>.y is not 1.</exception>
    public GridSensorBase(
        string name,
        Vector3 cellScale,
        Vector3Int gridSize,
        string[] detectableTags,
        SensorCompressionType compression
    )
    {
        m_Name = name;
        m_CellScale = cellScale;
        m_GridSize = gridSize;
        m_DetectableTags = detectableTags;
        // Goes through the property setter so PNG is refused (with a warning) when
        // IsDataNormalized() returns false. Note: this invokes a virtual method during construction.
        CompressionType = compression;

        if (m_GridSize.y != 1)
        {
            throw new UnityAgentsException("GridSensor only supports 2D grids.");
        }

        m_NumCells = m_GridSize.x * m_GridSize.z;
        // Virtual call during construction: subclasses override this to size each cell's data.
        m_CellObservationSize = GetCellObservationSize();
        m_ObservationSpec = ObservationSpec.Visual(m_GridSize.x, m_GridSize.z, m_CellObservationSize);
        m_PerceptionTexture = new Texture2D(m_GridSize.x, m_GridSize.z, TextureFormat.RGB24, false);
        ResetPerceptionBuffer();
    }

    /// <summary>
    /// The compression type used by the sensor.
    /// Setting <see cref="SensorCompressionType.PNG"/> is ignored (a warning is logged and the
    /// previous value is kept) when <see cref="IsDataNormalized"/> returns false.
    /// </summary>
    public SensorCompressionType CompressionType
    {
        get { return m_CompressionType; }
        set
        {
            if (!IsDataNormalized() && value == SensorCompressionType.PNG)
            {
                Debug.LogWarning($"Compression type {value} is only supported with normalized data. " +
                    "The sensor will not compress the data.");
                return;
            }
            m_CompressionType = value;
        }
    }

    /// <summary>
    /// The flat observation buffer: <c>m_CellObservationSize</c> floats per cell, for every cell in the grid.
    /// </summary>
    internal float[] PerceptionBuffer
    {
        get { return m_PerceptionBuffer; }
    }

    /// <summary>
    /// The tags which the sensor detects.
    /// </summary>
    protected string[] DetectableTags
    {
        get { return m_DetectableTags; }
    }

    /// <inheritdoc/>
    public void Reset() { }

    /// <summary>
    /// Clears the perception buffer before loading in new data.
    /// On the first call the buffers are allocated instead (a fresh array is already zeroed).
    /// </summary>
    public void ResetPerceptionBuffer()
    {
        if (m_PerceptionBuffer != null)
        {
            Array.Clear(m_PerceptionBuffer, 0, m_PerceptionBuffer.Length);
            Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
        }
        else
        {
            m_PerceptionBuffer = new float[m_CellObservationSize * m_NumCells];
            m_CellDataBuffer = new float[m_CellObservationSize];
            m_PerceptionColors = new Color[m_NumCells];
        }
    }

    /// <inheritdoc/>
    public string GetName()
    {
        return m_Name;
    }

    /// <inheritdoc/>
    public CompressionSpec GetCompressionSpec()
    {
        return new CompressionSpec(CompressionType);
    }

    /// <inheritdoc/>
    public BuiltInSensorType GetBuiltInSensorType()
    {
        return BuiltInSensorType.GridSensor;
    }

    /// <inheritdoc/>
    /// <remarks>
    /// The observation is split into groups of up to 3 channels. Each group is encoded as one
    /// PNG image, and the PNG byte streams are concatenated in channel order.
    /// </remarks>
    public byte[] GetCompressedObservation()
    {
        using (TimerStack.Instance.Scoped("GridSensor.GetCompressedObservation"))
        {
            var allBytes = new List<byte>();
            // Each image holds up to 3 channels (RGB); round up so a partial group still gets an image.
            var numImages = (m_CellObservationSize + 2) / 3;
            for (int i = 0; i < numImages; i++)
            {
                var channelIndex = 3 * i;
                GridValuesToTexture(channelIndex, Math.Min(3, m_CellObservationSize - channelIndex));
                allBytes.AddRange(m_PerceptionTexture.EncodeToPNG());
            }
            return allBytes.ToArray();
        }
    }

    /// <summary>
    /// Convert observation values to texture for PNG compression.
    /// </summary>
    /// <param name="channelIndex">First observation channel to place in the red component.</param>
    /// <param name="numChannelsToAdd">How many consecutive channels (1 to 3) to copy into R, G, B.</param>
    void GridValuesToTexture(int channelIndex, int numChannelsToAdd)
    {
        for (int i = 0; i < m_NumCells; i++)
        {
            var offset = i * m_CellObservationSize + channelIndex;
            for (int j = 0; j < 3; j++)
            {
                // RGB components beyond numChannelsToAdd are zeroed so the last (partial) image
                // does not carry over values written for the previous image.
                m_PerceptionColors[i][j] = j < numChannelsToAdd ? m_PerceptionBuffer[offset + j] : 0f;
            }
        }
        m_PerceptionTexture.SetPixels(m_PerceptionColors);
    }

    /// <summary>
    /// Get the observation values of the detected game object.
    /// The default records the detected tag index plus one. The offset keeps it distinct from
    /// the zero that <see cref="ResetPerceptionBuffer"/> leaves in empty cells.
    ///
    /// This method can be overridden to encode the observation differently or to get custom data
    /// from the object. When overriding this method, <see cref="GetCellObservationSize"/> and
    /// <see cref="IsDataNormalized"/> might also need to change accordingly.
    /// </summary>
    /// <param name="detectedObject">The game object that was detected within a certain cell</param>
    /// <param name="tagIndex">The index of the detectedObject's tag in the DetectableObjects list</param>
    /// <param name="dataBuffer">The buffer to write the observation values.
    /// The buffer size is configured by <see cref="GetCellObservationSize"/>.
    /// </param>
    /// <example>
    /// Here is an example of overriding GetObjectData to get the velocity of a potential Rigidbody
    /// (this writes 3 values, so GetCellObservationSize must return at least 3):
    /// <code>
    /// protected override void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
    /// {
    ///     if (tagIndex == Array.IndexOf(DetectableTags, "RigidBodyObject"))
    ///     {
    ///         Rigidbody rigidbody = detectedObject.GetComponent&lt;Rigidbody&gt;();
    ///         dataBuffer[0] = rigidbody.velocity.x;
    ///         dataBuffer[1] = rigidbody.velocity.y;
    ///         dataBuffer[2] = rigidbody.velocity.z;
    ///     }
    /// }
    /// </code>
    /// </example>
    protected virtual void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
    {
        dataBuffer[0] = tagIndex + 1;
    }

    /// <summary>
    /// Get the observation size for each cell. This will be the size of dataBuffer for <see cref="GetObjectData"/>.
    /// If overriding <see cref="GetObjectData"/>, override this method as well to return the custom observation size.
    /// </summary>
    /// <returns>The observation size of each cell.</returns>
    protected virtual int GetCellObservationSize()
    {
        return 1;
    }

    /// <summary>
    /// Whether the data is normalized within [0, 1]. The sensor can only use PNG compression if the data is normalized.
    /// If overriding <see cref="GetObjectData"/>, override this method as well according to the custom observation values.
    /// </summary>
    /// <returns>Bool value indicating whether data is normalized.</returns>
    protected virtual bool IsDataNormalized()
    {
        return false;
    }

    /// <summary>
    /// How detected colliders in a cell are processed. The default is
    /// <see cref="ProcessCollidersMethod.ProcessClosestColliders"/>, which uses only the collider
    /// closest to the agent.
    /// If overriding <see cref="GetObjectData"/>, consider overriding this method when needed.
    /// </summary>
    /// <returns>The <see cref="ProcessCollidersMethod"/> the sensor should use.</returns>
    protected internal virtual ProcessCollidersMethod GetProcessCollidersMethod()
    {
        return ProcessCollidersMethod.ProcessClosestColliders;
    }

    /// <summary>
    /// If using PNG compression, check that every value is within [0, 1].
    /// </summary>
    /// <exception cref="UnityAgentsException">Thrown for a value outside [0, 1] or NaN when the compression type is PNG.</exception>
    void ValidateValues(float[] dataValues, GameObject detectedObject)
    {
        if (m_CompressionType != SensorCompressionType.PNG)
        {
            return;
        }

        for (int j = 0; j < dataValues.Length; j++)
        {
            // Negated range check so NaN (which fails every comparison) is rejected as well.
            if (!(dataValues[j] >= 0f && dataValues[j] <= 1f))
            {
                throw new UnityAgentsException($"When using compression type {m_CompressionType} the data value has to be normalized between 0-1. " +
                    $"Received value[{dataValues[j]}] for {detectedObject.name}");
            }
        }
    }

    /// <summary>
    /// Collect data from the detected object if one of the detectable tags is matched.
    /// Only the first matching tag (in <see cref="DetectableTags"/> order) is used.
    /// </summary>
    /// <param name="detectedObject">The object found in the cell; a null reference is ignored.</param>
    /// <param name="cellIndex">Index of the cell whose slice of the perception buffer is written.</param>
    internal void ProcessDetectedObject(GameObject detectedObject, int cellIndex)
    {
        // Plain reference check (not Unity's overloaded ==); it does not depend on the tag,
        // so it is done once before the loop.
        if (ReferenceEquals(detectedObject, null))
        {
            return;
        }

        Profiler.BeginSample("GridSensor.ProcessDetectedObject");
        try
        {
            for (var i = 0; i < m_DetectableTags.Length; i++)
            {
                if (!detectedObject.CompareTag(m_DetectableTags[i]))
                {
                    continue;
                }

                var cellOffset = cellIndex * m_CellObservationSize;
                if (GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders)
                {
                    // Seed with what is already stored for this cell so GetObjectData can build on it.
                    Array.Copy(m_PerceptionBuffer, cellOffset, m_CellDataBuffer, 0, m_CellObservationSize);
                }
                else
                {
                    Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
                }

                GetObjectData(detectedObject, i, m_CellDataBuffer);
                ValidateValues(m_CellDataBuffer, detectedObject);
                Array.Copy(m_CellDataBuffer, 0, m_PerceptionBuffer, cellOffset, m_CellObservationSize);
                break;
            }
        }
        finally
        {
            // Keep profiler samples balanced even if GetObjectData or ValidateValues throws.
            Profiler.EndSample();
        }
    }

    /// <inheritdoc/>
    public void Update()
    {
        ResetPerceptionBuffer();
        using (TimerStack.Instance.Scoped("GridSensor.Update"))
        {
            if (m_GridPerception != null)
            {
                m_GridPerception.Perceive();
            }
        }
    }

    /// <inheritdoc/>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Rows are emitted from the highest z index down to 0, so the first cell in the perception
    /// buffer lands in writer row <c>gridSize.z - 1</c>. Returns the number of floats written.
    /// </remarks>
    public int Write(ObservationWriter writer)
    {
        using (TimerStack.Instance.Scoped("GridSensor.Write"))
        {
            int index = 0;
            for (var h = m_GridSize.z - 1; h >= 0; h--)
            {
                for (var w = 0; w < m_GridSize.x; w++)
                {
                    for (var d = 0; d < m_CellObservationSize; d++)
                    {
                        writer[h, w, d] = m_PerceptionBuffer[index];
                        index++;
                    }
                }
            }
            return index;
        }
    }

    /// <summary>
    /// Clean up the internal objects.
    /// Safe to call more than once; the texture reference is cleared after it is destroyed.
    /// </summary>
    public void Dispose()
    {
        if (!ReferenceEquals(null, m_PerceptionTexture))
        {
            Utilities.DestroyTexture(m_PerceptionTexture);
            m_PerceptionTexture = null;
        }
    }
}
}