using System;
using System.Collections.Generic;
using MLAgents.InferenceBrain;
namespace MLAgents.Sensor
{
    /// <summary>
    /// Allows sensors to write to both TensorProxy and float arrays/lists.
    /// </summary>
    /// <remarks>
    /// Only one target is active at a time: calling either <c>SetTarget</c> overload
    /// replaces whichever target was set before.
    /// </remarks>
    public class WriteAdapter
    {
        // Flat list/array target; null while a TensorProxy is the target.
        IList<float> m_Data;

        // Offset added to every write index, for both kinds of target.
        int m_Offset;

        // Tensor target; null while an IList is the target.
        TensorProxy m_Proxy;

        // Batch row written in m_Proxy; -1 while an IList is the target.
        int m_Batch;

        /// <summary>
        /// Set the adapter to write to an IList at the given offset.
        /// </summary>
        /// <param name="data">The list or array to write into.</param>
        /// <param name="offset">Index in <paramref name="data"/> that write index 0 maps to.</param>
        public void SetTarget(IList<float> data, int offset)
        {
            m_Data = data;
            m_Offset = offset;
            m_Proxy = null;
            m_Batch = -1;
        }

        /// <summary>
        /// Set the adapter to write to a TensorProxy at the given batch and channel offset.
        /// </summary>
        /// <param name="tensorProxy">The tensor to write into.</param>
        /// <param name="batchIndex">Batch row that every write goes to.</param>
        /// <param name="channelOffset">Channel that write index (or channel) 0 maps to.</param>
        public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
        {
            m_Proxy = tensorProxy;
            m_Batch = batchIndex;
            m_Offset = channelOffset;
            m_Data = null;
        }

        /// <summary>
        /// 1D write access at a specified index. Use AddRange if possible instead.
        /// </summary>
        /// <param name="index">Index to write to, relative to the offset given to SetTarget.</param>
        /// <exception cref="InvalidOperationException">No target has been set.</exception>
        public float this[int index]
        {
            set
            {
                if (m_Data != null)
                {
                    m_Data[index + m_Offset] = value;
                }
                else
                {
                    GetProxyOrThrow().data[m_Batch, index + m_Offset] = value;
                }
            }
        }

        /// <summary>
        /// 3D write access at the specified height, width, and channel. Only usable with a TensorProxy target.
        /// </summary>
        /// <param name="h">Height index.</param>
        /// <param name="w">Width index.</param>
        /// <param name="ch">Channel index, relative to the channel offset given to SetTarget.</param>
        /// <exception cref="InvalidOperationException">
        /// The current target is an IList, or no target has been set.
        /// </exception>
        public float this[int h, int w, int ch]
        {
            set
            {
                // Only TensorProxy supports 3D access; fail with a clear message otherwise
                // instead of a NullReferenceException on m_Proxy.
                GetProxyOrThrow().data[m_Batch, h, w, ch + m_Offset] = value;
            }
        }

        /// <summary>
        /// Write the range of floats to consecutive indices.
        /// </summary>
        /// <param name="data">Values to write, in order.</param>
        /// <param name="writeOffset">Optional write offset, added on top of the offset given to SetTarget.</param>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
        /// <exception cref="InvalidOperationException">No target has been set.</exception>
        public void AddRange(IEnumerable<float> data, int writeOffset = 0)
        {
            if (data == null)
            {
                throw new ArgumentNullException(nameof(data));
            }

            // The first element lands here; later elements follow contiguously.
            var index = m_Offset + writeOffset;
            if (m_Data != null)
            {
                foreach (var val in data)
                {
                    m_Data[index] = val;
                    index++;
                }
            }
            else
            {
                var proxy = GetProxyOrThrow();
                foreach (var val in data)
                {
                    proxy.data[m_Batch, index] = val;
                    index++;
                }
            }
        }

        // Returns the tensor target, or throws a descriptive InvalidOperationException
        // explaining why none is available (IList target active vs. no target at all).
        TensorProxy GetProxyOrThrow()
        {
            if (m_Proxy != null)
            {
                return m_Proxy;
            }

            throw new InvalidOperationException(m_Data != null
                ? "WriteAdapter: this operation requires a TensorProxy target, but the current target is an IList."
                : "WriteAdapter: no target set; call SetTarget before writing.");
        }
    }
}