|
|
|
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using Barracuda; |
|
|
|
using MLAgents.InferenceBrain; |
|
|
|
|
|
|
|
namespace MLAgents.Sensor |
|
|
|
|
|
|
TensorProxy m_Proxy; |
|
|
|
int m_Batch; |
|
|
|
|
|
|
|
int[] m_Shape; |
|
|
|
TensorShape m_TensorShape; |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Set the adapter to write to an IList at the given channelOffset.
|
|
|
|
|
|
|
m_Offset = offset; |
|
|
|
m_Proxy = null; |
|
|
|
m_Batch = 0; |
|
|
|
m_Shape = shape; |
|
|
|
|
|
|
|
if (shape.Length == 1) |
|
|
|
{ |
|
|
|
m_TensorShape = new TensorShape(m_Batch, shape[0]); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
/// <param name="shape">Shape of the observations to be written.</param>
|
|
|
|
public void SetTarget(TensorProxy tensorProxy, int[] shape, int batchIndex, int channelOffset) |
|
|
|
public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset) |
|
|
|
m_Shape = shape; |
|
|
|
m_TensorShape = m_Proxy.data.shape; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
{ |
|
|
|
set |
|
|
|
{ |
|
|
|
// TODO check shape is 1D?
|
|
|
|
if (m_Data != null) |
|
|
|
{ |
|
|
|
m_Data[index + m_Offset] = value; |
|
|
|
|
|
|
{ |
|
|
|
if (m_Data != null) |
|
|
|
{ |
|
|
|
var height = m_Shape[0]; |
|
|
|
var width = m_Shape[1]; |
|
|
|
var channels = m_Shape[2]; |
|
|
|
|
|
|
|
if (h < 0 || h >= height) |
|
|
|
if (h < 0 || h >= m_TensorShape.height) |
|
|
|
throw new IndexOutOfRangeException($"height value {h} must be in range [0, {height-1}]"); |
|
|
|
throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height-1}]"); |
|
|
|
if (w < 0 || w >= width) |
|
|
|
if (w < 0 || w >= m_TensorShape.width) |
|
|
|
throw new IndexOutOfRangeException($"width value {w} must be in range [0, {width-1}]"); |
|
|
|
throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width-1}]"); |
|
|
|
if (ch < 0 || ch >= channels) |
|
|
|
if (ch < 0 || ch >= m_TensorShape.channels) |
|
|
|
throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {channels-1}]"); |
|
|
|
throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels-1}]"); |
|
|
|
// Math copied from TensorShape.Index(). Note that m_Batch should always be 0
|
|
|
|
var index = m_Batch * height * width * channels + h * width * channels + w * channels + ch; |
|
|
|
m_Data[index + m_Offset] = value; |
|
|
|
var index = m_TensorShape.Index(m_Batch, h, w, ch + m_Offset); |
|
|
|
m_Data[index] = value; |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|