using System;
using System.Collections.Generic;
using Barracuda;
using MLAgents.InferenceBrain;
namespace MLAgents.Sensor
{
    /// <summary>
    /// Allows sensors to write to both TensorProxy and float arrays/lists.
    /// </summary>
    /// <remarks>
    /// Exactly one target is active at a time: a float list (after
    /// <see cref="SetTarget(IList{float}, int[], int)"/>) or a <see cref="TensorProxy"/>
    /// (after <see cref="SetTarget(TensorProxy, int, int)"/>). Each call to either
    /// overload clears the other target. A SetTarget overload must be called before any write.
    /// </remarks>
    public class WriteAdapter
    {
        // Float list target; null while a TensorProxy is the target.
        IList<float> m_Data;

        // Added to the last (flat / channel) index of every write.
        int m_Offset;

        // Tensor proxy target; null while a float list is the target.
        TensorProxy m_Proxy;

        // Batch (agent) index inside the proxy; always 0 for list targets.
        int m_Batch;

        // Shape used by the 3D indexer for bounds checks and flat-index computation on list targets.
        TensorShape m_TensorShape;

        /// <summary>
        /// Set the adapter to write to an IList at the given channelOffset.
        /// </summary>
        /// <param name="data">Float array or list that will be written to.</param>
        /// <param name="shape">Shape of the observations to be written: either [size] or [height, width, channels].</param>
        /// <param name="offset">Offset from the start of the float data to write to.</param>
        /// <exception cref="ArgumentNullException"><paramref name="shape"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="shape"/> has neither 1 nor at least 3 dimensions.</exception>
        public void SetTarget(IList<float> data, int[] shape, int offset)
        {
            if (shape == null)
            {
                throw new ArgumentNullException(nameof(shape));
            }

            // Reject ranks that would otherwise fail with an opaque IndexOutOfRangeException below.
            if (shape.Length != 1 && shape.Length < 3)
            {
                throw new ArgumentException(
                    $"shape must have 1 or 3 dimensions, but has {shape.Length}", nameof(shape));
            }

            m_Data = data;
            m_Offset = offset;
            m_Proxy = null;
            m_Batch = 0;

            // NOTE(review): as before, dimensions beyond the third are ignored — confirm callers never pass rank > 3.
            m_TensorShape = shape.Length == 1
                ? new TensorShape(m_Batch, shape[0])
                : new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
        }

        /// <summary>
        /// Set the adapter to write to a TensorProxy at the given batch and channel offset.
        /// </summary>
        /// <param name="tensorProxy">Tensor proxy that will be written to.</param>
        /// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent)</param>
        /// <param name="channelOffset">Offset from the start of the channel to write to.</param>
        /// <exception cref="ArgumentNullException"><paramref name="tensorProxy"/> is null.</exception>
        public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
        {
            if (tensorProxy == null)
            {
                throw new ArgumentNullException(nameof(tensorProxy));
            }

            m_Proxy = tensorProxy;
            m_Batch = batchIndex;
            m_Offset = channelOffset;
            m_Data = null;
            m_TensorShape = m_Proxy.data.shape;
        }

        /// <summary>
        /// 1D write access at a specified index. Use AddRange if possible instead.
        /// </summary>
        /// <param name="index">Index to write to, relative to the configured offset.</param>
        public float this[int index]
        {
            set
            {
                if (m_Data != null)
                {
                    m_Data[index + m_Offset] = value;
                }
                else
                {
                    m_Proxy.data[m_Batch, index + m_Offset] = value;
                }
            }
        }

        /// <summary>
        /// 3D write access at the specified height, width, and channel.
        /// Works with either target; bounds are only checked here for list targets.
        /// </summary>
        /// <param name="h">Height index.</param>
        /// <param name="w">Width index.</param>
        /// <param name="ch">Channel index, relative to the configured offset.</param>
        /// <exception cref="IndexOutOfRangeException">
        /// A list target is active and <paramref name="h"/>, <paramref name="w"/> or
        /// <paramref name="ch"/> is outside the shape passed to SetTarget.
        /// </exception>
        public float this[int h, int w, int ch]
        {
            set
            {
                if (m_Data != null)
                {
                    if (h < 0 || h >= m_TensorShape.height)
                    {
                        throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height-1}]");
                    }
                    if (w < 0 || w >= m_TensorShape.width)
                    {
                        throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width-1}]");
                    }
                    if (ch < 0 || ch >= m_TensorShape.channels)
                    {
                        throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels-1}]");
                    }

                    // Flatten (batch, h, w, ch) through the stored shape to a position in the list.
                    var index = m_TensorShape.Index(m_Batch, h, w, ch + m_Offset);
                    m_Data[index] = value;
                }
                else
                {
                    m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value;
                }
            }
        }

        /// <summary>
        /// Write the range of floats, consecutively, starting at the configured offset.
        /// </summary>
        /// <param name="data">Values to write, in order.</param>
        /// <param name="writeOffset">Optional write offset</param>
        public void AddRange(IEnumerable<float> data, int writeOffset = 0)
        {
            // Each value goes to the next slot; the index must advance per element,
            // otherwise every value overwrites the same position.
            var index = m_Offset + writeOffset;
            if (m_Data != null)
            {
                foreach (var val in data)
                {
                    m_Data[index] = val;
                    index++;
                }
            }
            else
            {
                foreach (var val in data)
                {
                    m_Proxy.data[m_Batch, index] = val;
                    index++;
                }
            }
        }
    }
}