using System;
using System.Linq;
using System.Runtime.CompilerServices;
using UnityEngine;
using Unity.Barracuda;
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// Sensor that wraps around another Sensor to provide temporal stacking.
/// Conceptually, consecutive observations are stored left-to-right, which is how they're output.
/// For example, 4 stacked sets of observations would be output like
/// | t = now - 3 | t = now - 2 | t = now - 1 | t = now |
/// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation.
/// Currently, observations are stacked on the last dimension.
/// </summary>
public class StackingSensor : ISparseChannelSensor
{
/// <summary>
/// The wrapped sensor.
/// </summary>
ISensor m_WrappedSensor;
/// <summary>
/// Number of stacks to save
/// </summary>
int m_NumStackedObservations;
// Number of floats in a single (unstacked) observation, as reported by the wrapped sensor's ObservationSize().
int m_UnstackedObservationSize;
// Name returned by GetName(); built once in the constructor from the stack size and the wrapped sensor's name.
string m_Name;
// Shape of the stacked observation: the wrapped shape with its last dimension multiplied by the stack count.
int[] m_Shape;
// Shape reported by the wrapped sensor at construction time.
int[] m_WrappedShape;
/// <summary>
/// Buffer of previous observations
/// </summary>
float[][] m_StackedObservations;
// Per-slot compressed observations. Only allocated when the wrapped sensor's compression type is not None.
byte[][] m_StackedCompressedObservations;
// Slot of the circular buffer that receives the next observation; advanced (with wrap-around) by Update().
int m_CurrentIndex;
// Writer that Write() points at the current slot so the wrapped sensor fills our buffer, not the caller's.
ObservationWriter m_LocalWriter = new ObservationWriter();
// Placeholder produced by CreateEmptyPNG() for slots without a real compressed observation.
// Only assigned when the wrapped sensor's compression type is not None.
byte[] m_EmptyCompressedObservation;
// Channel mapping returned by GetCompressedChannelMapping(). Only assigned when the wrapped sensor's
// compression type is not None.
int[] m_CompressionMapping;
// Used by Write() to turn (h, w, c) into a flat index into a stored observation.
// Only assigned when the observation shape does not have exactly one dimension.
TensorShape m_tensorShape;
/// <summary>
/// Initializes the sensor.
/// </summary>
/// <param name="wrapped">The wrapped sensor.</param>
/// <param name="numStackedObservations">Number of stacked observations to keep. Must be at least 1.</param>
/// <exception cref="ArgumentNullException"><paramref name="wrapped"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="numStackedObservations"/> is less than 1.</exception>
public StackingSensor(ISensor wrapped, int numStackedObservations)
{
    if (wrapped == null)
    {
        throw new ArgumentNullException(nameof(wrapped));
    }
    // The stack count is used as an array length and as a modulus for the circular buffer index,
    // so zero or negative values would only fail later with a misleading exception.
    if (numStackedObservations < 1)
    {
        throw new ArgumentOutOfRangeException(
            nameof(numStackedObservations),
            numStackedObservations,
            "The number of stacked observations must be at least 1.");
    }

    m_WrappedSensor = wrapped;
    m_NumStackedObservations = numStackedObservations;
    m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";

    m_WrappedShape = wrapped.GetObservationShape();
    m_UnstackedObservationSize = wrapped.ObservationSize();

    // The stacked shape is a private copy of the wrapped shape with the last dimension scaled.
    // TODO support arbitrary stacking dimension
    m_Shape = (int[])m_WrappedShape.Clone();
    m_Shape[m_Shape.Length - 1] *= numStackedObservations;

    // Initialize uncompressed buffer anyway in case python trainer does not
    // support the compression mapping and has to fall back to uncompressed obs.
    m_StackedObservations = new float[numStackedObservations][];
    for (var i = 0; i < numStackedObservations; i++)
    {
        m_StackedObservations[i] = new float[m_UnstackedObservationSize];
    }

    if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None)
    {
        // Every slot starts out sharing the same placeholder image until a real observation replaces it.
        m_StackedCompressedObservations = new byte[numStackedObservations][];
        m_EmptyCompressedObservation = CreateEmptyPNG();
        for (var i = 0; i < numStackedObservations; i++)
        {
            m_StackedCompressedObservations[i] = m_EmptyCompressedObservation;
        }
        m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
    }

    if (m_Shape.Length != 1)
    {
        // Any shape that is not one-dimensional is treated as (height, width, channels).
        m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]);
    }
}
/// <summary>
/// Lets the wrapped sensor write its newest observation into the current slot of the internal
/// buffer, then writes all buffered observations to <paramref name="writer"/>, oldest first.
/// </summary>
/// <param name="writer">Destination for the stacked observation.</param>
/// <returns>The number of floats written.</returns>
public int Write(ObservationWriter writer)
{
    // The wrapped sensor must fill our own slot, so hand it the local writer rather than the caller's.
    var newestSlot = m_StackedObservations[m_CurrentIndex];
    m_LocalWriter.SetTarget(newestSlot, m_WrappedShape, 0);
    m_WrappedSensor.Write(m_LocalWriter);

    // One-dimensional observations are simply concatenated, oldest first.
    if (m_WrappedShape.Length == 1)
    {
        var written = 0;
        for (var step = 0; step < m_NumStackedObservations; step++)
        {
            // The slot right after the current one holds the oldest observation.
            var slot = (m_CurrentIndex + 1 + step) % m_NumStackedObservations;
            writer.AddRange(m_StackedObservations[slot], written);
            written += m_UnstackedObservationSize;
        }
        return written;
    }

    // Otherwise the shape is read as (height, width, channels) and stacking happens on the channel axis.
    var height = m_WrappedShape[0];
    var width = m_WrappedShape[1];
    var channels = m_WrappedShape[2];
    for (var step = 0; step < m_NumStackedObservations; step++)
    {
        var frame = m_StackedObservations[(m_CurrentIndex + 1 + step) % m_NumStackedObservations];
        var channelBase = step * channels;
        for (var row = 0; row < height; row++)
        {
            for (var col = 0; col < width; col++)
            {
                for (var ch = 0; ch < channels; ch++)
                {
                    writer[row, col, channelBase + ch] = frame[m_tensorShape.Index(0, row, col, ch)];
                }
            }
        }
    }
    return height * width * channels * m_NumStackedObservations;
}
/// <summary>
/// Updates the wrapped sensor, then moves the "current" buffer index to the next slot,
/// wrapping back to the first slot after the last one.
/// </summary>
public void Update()
{
    m_WrappedSensor.Update();

    var nextIndex = m_CurrentIndex + 1;
    m_CurrentIndex = nextIndex % m_NumStackedObservations;
}
/// <summary>
/// Resets the wrapped sensor and discards all buffered observations: every float buffer is
/// zeroed and, for compressed sensors, every slot is pointed back at the empty placeholder.
/// </summary>
public void Reset()
{
    m_WrappedSensor.Reset();

    // Zero out the uncompressed history.
    var slot = 0;
    while (slot < m_NumStackedObservations)
    {
        var buffer = m_StackedObservations[slot];
        Array.Clear(buffer, 0, buffer.Length);
        slot++;
    }

    // Nothing more to do when the wrapped sensor is not compressed.
    if (m_WrappedSensor.GetCompressionType() == SensorCompressionType.None)
    {
        return;
    }

    for (slot = 0; slot < m_NumStackedObservations; slot++)
    {
        m_StackedCompressedObservations[slot] = m_EmptyCompressedObservation;
    }
}
/// <summary>
/// Gets the shape of the stacked observation: the wrapped sensor's shape with the last
/// dimension multiplied by the number of stacked observations.
/// </summary>
/// <returns>The stacked observation shape.</returns>
public int[] GetObservationShape() => m_Shape;
/// <summary>
/// Gets the sensor name, which is derived from the stack size and the wrapped sensor's name.
/// </summary>
/// <returns>The name assigned in the constructor.</returns>
public string GetName() => m_Name;
/// <summary>
/// Stores the wrapped sensor's latest compressed observation in the current slot and returns
/// the concatenation of all buffered compressed observations, oldest first.
/// </summary>
/// <returns>
/// The concatenated compressed observations, or null when the wrapped sensor does not use
/// compression (no compressed buffer exists in that case).
/// </returns>
public byte[] GetCompressedObservation()
{
    // The compressed buffer is only allocated by the constructor when the wrapped sensor's
    // compression type is not None; without it there is nothing to stack.
    if (m_StackedCompressedObservations == null)
    {
        return null;
    }

    m_StackedCompressedObservations[m_CurrentIndex] = m_WrappedSensor.GetCompressedObservation();

    // Slots can differ in size, so the output length has to be summed up first.
    var totalLength = 0;
    foreach (var frame in m_StackedCompressedObservations)
    {
        totalLength += frame.Length;
    }

    var outputBytes = new byte[totalLength];
    var offset = 0;
    for (var i = 0; i < m_NumStackedObservations; i++)
    {
        // The slot right after the current one holds the oldest observation.
        var frame = m_StackedCompressedObservations[(m_CurrentIndex + 1 + i) % m_NumStackedObservations];
        Buffer.BlockCopy(frame, 0, outputBytes, offset, frame.Length);
        offset += frame.Length;
    }
    return outputBytes;
}
/// <summary>
/// Gets the channel mapping for the stacked compressed observation.
/// </summary>
/// <returns>
/// The mapping built in the constructor; it is only built when the wrapped sensor's
/// compression type is not None.
/// </returns>
public int[] GetCompressedChannelMapping() => m_CompressionMapping;
/// <summary>
/// Gets the compression type, which is always that of the wrapped sensor.
/// </summary>
/// <returns>The wrapped sensor's compression type.</returns>
public SensorCompressionType GetCompressionType() => m_WrappedSensor.GetCompressionType();
/// <summary>
/// Creates the PNG used to fill stack slots that do not hold a real observation yet.
/// </summary>
/// <returns>
/// PNG-encoded bytes of an all-black RGB24 image whose height and width are the first two
/// dimensions of the wrapped sensor's observation shape.
/// </returns>
internal byte[] CreateEmptyPNG()
{
    var wrappedShape = m_WrappedSensor.GetObservationShape();
    int height = wrappedShape[0];
    int width = wrappedShape[1];
    var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
    try
    {
        // Explicitly paint the texture black so the placeholder decodes to zeros, matching the
        // zero-filled float buffers; the pixels of a newly created texture are not relied upon.
        // A new Color32 array is all zeros, i.e. black.
        texture2D.SetPixels32(new Color32[width * height]);
        texture2D.Apply();
        return texture2D.EncodeToPNG();
    }
    finally
    {
        // The texture is only needed for encoding; release it instead of leaving it for Unity
        // to clean up. Destroy() is not allowed outside play mode, hence the distinction.
        if (Application.isPlaying)
        {
            UnityEngine.Object.Destroy(texture2D);
        }
        else
        {
            UnityEngine.Object.DestroyImmediate(texture2D);
        }
    }
}
/// <summary>
/// Constructs the stacked CompressedChannelMapping from the mapping of the wrapped sensor.
/// </summary>
/// <param name="wrappedSenesor">
/// The sensor being stacked. Its observation shape must have a third dimension (channels).
/// </param>
/// <returns>
/// The wrapped mapping padded with -1 to a multiple of 3 and repeated once per stacked
/// observation, with each copy's non-negative entries offset by the wrapped channel count.
/// </returns>
internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSenesor)
{
    // Get CompressedChannelMapping of the wrapped sensor. If the
    // wrapped sensor doesn't have one, use default mapping.
    // Default mapping: {0, 0, 0} for grayscale, identity mapping {0, 1, ..., n - 1} otherwise.
    int wrappedNumChannel = wrappedSenesor.GetObservationShape()[2];
    var sparseChannelSensor = wrappedSenesor as ISparseChannelSensor;
    int[] wrappedMapping = sparseChannelSensor?.GetCompressedChannelMapping();
    if (wrappedMapping == null)
    {
        wrappedMapping = wrappedNumChannel == 1
            ? new int[] { 0, 0, 0 }
            : Enumerable.Range(0, wrappedNumChannel).ToArray();
    }

    // Construct stacked mapping using the mapping of wrapped sensor.
    // First pad the wrapped mapping to multiple of 3, then repeat
    // and add offset to each copy to form the stacked mapping.
    int paddedMapLength = (wrappedMapping.Length + 2) / 3 * 3;
    var compressionMapping = new int[paddedMapLength * m_NumStackedObservations];
    for (var i = 0; i < m_NumStackedObservations; i++)
    {
        var offset = wrappedNumChannel * i;
        var copyStart = paddedMapLength * i;
        for (var j = 0; j < paddedMapLength; j++)
        {
            // Padding slots and negative (unused) entries stay -1; real entries are shifted
            // into the channel range of the i-th stacked observation.
            var isMapped = j < wrappedMapping.Length && wrappedMapping[j] >= 0;
            compressionMapping[copyStart + j] = isMapped ? wrappedMapping[j] + offset : -1;
        }
    }
    return compressionMapping;
}
}
}