using System;
using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Inference;
using UnityEngine;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// Allows sensors to write to both TensorProxy and float arrays/lists.
    /// </summary>
    /// <remarks>
    /// Exactly one target is active at a time: either a flat float list (configured by the
    /// list overload of SetTarget) or one batch row of a <see cref="TensorProxy"/> (configured
    /// by the tensor overload of SetTarget). Every write is shifted by the offset given to the
    /// most recent SetTarget call.
    /// </remarks>
    public class ObservationWriter
    {
        // Flat output list; non-null only while a list target is active.
        IList<float> m_Data;

        // Added to every write index: a flat offset for list targets, a channel offset for tensor targets.
        int m_Offset;

        // Tensor output; non-null only while a tensor target is active.
        TensorProxy m_Proxy;

        // Batch row (agent index) of m_Proxy to write to; always 0 for list targets.
        int m_Batch;

        // Shape of the current target; used to bounds-check and flatten 3D writes into a list target.
        TensorShape m_TensorShape;

        internal ObservationWriter() {}

        /// <summary>
        /// Set the writer to write to an IList at the given offset.
        /// </summary>
        /// <param name="data">Float array or list that will be written to.</param>
        /// <param name="shape">
        /// Shape of the observations to be written. Must have either 1 dimension (vector)
        /// or 3 dimensions (height, width, channels).
        /// </param>
        /// <param name="offset">Offset from the start of the float data to write to.</param>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> or <paramref name="shape"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="shape"/> does not have 1 or 3 dimensions.</exception>
        internal void SetTarget(IList<float> data, int[] shape, int offset)
        {
            // Validate before touching any field so a bad call cannot leave the writer half-configured.
            if (data == null)
            {
                throw new ArgumentNullException(nameof(data));
            }
            if (shape == null)
            {
                throw new ArgumentNullException(nameof(shape));
            }
            if (shape.Length != 1 && shape.Length != 3)
            {
                // Previously rank 0/2 crashed with an opaque IndexOutOfRangeException and
                // rank > 3 was silently truncated, corrupting later 3D index computations.
                throw new ArgumentException(
                    $"Observation shape must have 1 or 3 dimensions, but has {shape.Length}.",
                    nameof(shape));
            }

            m_Data = data;
            m_Offset = offset;
            m_Proxy = null;
            m_Batch = 0;
            m_TensorShape = shape.Length == 1
                ? new TensorShape(m_Batch, shape[0])
                : new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
        }

        /// <summary>
        /// Set the writer to write to a TensorProxy at the given batch and channel offset.
        /// </summary>
        /// <param name="tensorProxy">Tensor proxy that will be written to.</param>
        /// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent).</param>
        /// <param name="channelOffset">Offset from the start of the channel to write to.</param>
        /// <exception cref="ArgumentNullException"><paramref name="tensorProxy"/> is null.</exception>
        internal void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
        {
            // Checked up front: the shape is read from the proxy below, and failing there would
            // leave both targets cleared.
            if (tensorProxy == null)
            {
                throw new ArgumentNullException(nameof(tensorProxy));
            }

            m_Proxy = tensorProxy;
            m_Batch = batchIndex;
            m_Offset = channelOffset;
            m_Data = null;
            m_TensorShape = m_Proxy.data.shape;
        }

        /// <summary>
        /// 1D write access at a specified index. Use AddRange if possible instead.
        /// </summary>
        /// <param name="index">Index to write to, relative to the current offset.</param>
        public float this[int index]
        {
            set { WriteAt(index, value); }
        }

        /// <summary>
        /// 3D write access at the specified height, width, and channel.
        /// </summary>
        /// <param name="h">Height index.</param>
        /// <param name="w">Width index.</param>
        /// <param name="ch">Channel index, relative to the current offset.</param>
        /// <exception cref="IndexOutOfRangeException">
        /// For a list target, thrown when <paramref name="h"/>, <paramref name="w"/>, or
        /// <paramref name="ch"/> is outside the shape passed to SetTarget.
        /// </exception>
        public float this[int h, int w, int ch]
        {
            set
            {
                if (m_Data != null)
                {
                    if (h < 0 || h >= m_TensorShape.height)
                    {
                        throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height - 1}]");
                    }
                    if (w < 0 || w >= m_TensorShape.width)
                    {
                        throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width - 1}]");
                    }
                    if (ch < 0 || ch >= m_TensorShape.channels)
                    {
                        throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels - 1}]");
                    }

                    // NOTE(review): assumes TensorShape.Index flattens with channels innermost, so
                    // adding m_Offset to the channel shifts the flat index by m_Offset — confirm
                    // against Barracuda's TensorShape.Index.
                    var index = m_TensorShape.Index(m_Batch, h, w, ch + m_Offset);
                    m_Data[index] = value;
                }
                else
                {
                    m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value;
                }
            }
        }

        /// <summary>
        /// Write the range of floats.
        /// </summary>
        /// <param name="data">
        /// Values to write, in order, starting at the current offset plus <paramref name="writeOffset"/>.
        /// </param>
        /// <param name="writeOffset">Optional write offset.</param>
        public void AddRange(IEnumerable<float> data, int writeOffset = 0)
        {
            // The target is chosen once per call rather than per element; this is the hot path
            // for vector observations.
            if (m_Data != null)
            {
                var index = 0;
                foreach (var val in data)
                {
                    m_Data[index + m_Offset + writeOffset] = val;
                    index++;
                }
            }
            else
            {
                var index = 0;
                foreach (var val in data)
                {
                    m_Proxy.data[m_Batch, index + m_Offset + writeOffset] = val;
                    index++;
                }
            }
        }

        /// <summary>
        /// Write the Vector3 components.
        /// </summary>
        /// <param name="vec">The Vector3 to be written (x, y, z in that order).</param>
        /// <param name="writeOffset">Optional write offset.</param>
        public void Add(Vector3 vec, int writeOffset = 0)
        {
            WriteAt(writeOffset + 0, vec.x);
            WriteAt(writeOffset + 1, vec.y);
            WriteAt(writeOffset + 2, vec.z);
        }

        /// <summary>
        /// Write the Vector4 components.
        /// </summary>
        /// <param name="vec">The Vector4 to be written (x, y, z, w in that order).</param>
        /// <param name="writeOffset">Optional write offset.</param>
        public void Add(Vector4 vec, int writeOffset = 0)
        {
            WriteAt(writeOffset + 0, vec.x);
            WriteAt(writeOffset + 1, vec.y);
            WriteAt(writeOffset + 2, vec.z);
            WriteAt(writeOffset + 3, vec.w);
        }

        /// <summary>
        /// Write the Quaternion components.
        /// </summary>
        /// <param name="quat">The Quaternion to be written (x, y, z, w in that order).</param>
        /// <param name="writeOffset">Optional write offset.</param>
        public void Add(Quaternion quat, int writeOffset = 0)
        {
            WriteAt(writeOffset + 0, quat.x);
            WriteAt(writeOffset + 1, quat.y);
            WriteAt(writeOffset + 2, quat.z);
            WriteAt(writeOffset + 3, quat.w);
        }

        /// <summary>
        /// Writes one value at <paramref name="index"/> (relative to the current offset) into
        /// whichever target is active: the flat list if set, otherwise the tensor's batch row.
        /// </summary>
        void WriteAt(int index, float value)
        {
            if (m_Data != null)
            {
                m_Data[index + m_Offset] = value;
            }
            else
            {
                m_Proxy.data[m_Batch, index + m_Offset] = value;
            }
        }
    }
}