using System;
using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Inference;
using UnityEngine;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// Allows sensors to write to both TensorProxy and float arrays/lists.
    /// </summary>
    public class ObservationWriter
    {
        // Exactly one target is active at a time: either m_Data (list target) or
        // m_Proxy (tensor target). Each SetTarget overload clears the other one.
        IList<float> m_Data;
        int m_Offset;

        TensorProxy m_Proxy;
        int m_Batch;

        // Shape used for 3D index computation and bounds checks on the list target.
        TensorShape m_TensorShape;

        internal ObservationWriter() { }

        /// <summary>
        /// Set the writer to write to an IList at the given offset.
        /// </summary>
        /// <param name="data">Float array or list that will be written to.</param>
        /// <param name="observationSpec">ObservationSpec of the observation to be written.</param>
        /// <param name="offset">Offset from the start of the float data to write to.</param>
        internal void SetTarget(IList<float> data, ObservationSpec observationSpec, int offset)
        {
            // TODO remove int[] version
            SetTarget(data, observationSpec.Shape, offset);
        }

        /// <summary>
        /// Set the writer to write to an IList at the given offset.
        /// </summary>
        /// <param name="data">Float array or list that will be written to.</param>
        /// <param name="shape">Shape of the observations to be written.</param>
        /// <param name="offset">Offset from the start of the float data to write to.</param>
        internal void SetTarget(IList<float> data, InplaceArray<int> shape, int offset)
        {
            m_Data = data;
            m_Offset = offset;
            m_Proxy = null;
            m_Batch = 0;

            if (shape.Length == 1)
            {
                m_TensorShape = new TensorShape(m_Batch, shape[0]);
            }
            else if (shape.Length == 2)
            {
                // Rank-2 observations are given a leading height of 1.
                m_TensorShape = new TensorShape(new[] { m_Batch, 1, shape[0], shape[1] });
            }
            else
            {
                // Any other rank is read as (height, width, channels).
                // NOTE(review): entries past index 2 are ignored here; a rank-0 shape is
                // presumably never passed - confirm against callers.
                m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
            }
        }

        /// <summary>
        /// Set the writer to write to a TensorProxy at the given batch and channel offset.
        /// </summary>
        /// <param name="tensorProxy">Tensor proxy that will be written to.</param>
        /// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent).</param>
        /// <param name="channelOffset">Offset from the start of the channel to write to.</param>
        internal void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
        {
            m_Proxy = tensorProxy;
            m_Batch = batchIndex;
            m_Offset = channelOffset;
            m_Data = null;
            m_TensorShape = m_Proxy.data.shape;
        }

        /// <summary>
        /// Writes one value at <paramref name="index"/> positions past the writer's offset,
        /// into whichever target (list or tensor proxy) is currently set.
        /// </summary>
        /// <param name="index">Position relative to the offset given to SetTarget.</param>
        /// <param name="value">Value to store.</param>
        void WriteFlat(int index, float value)
        {
            if (m_Data != null)
            {
                m_Data[index + m_Offset] = value;
            }
            else
            {
                m_Proxy.data[m_Batch, index + m_Offset] = value;
            }
        }

        /// <summary>
        /// 1D write access at a specified index. Use AddList if possible instead.
        /// </summary>
        /// <param name="index">Index to write to.</param>
        public float this[int index]
        {
            set { WriteFlat(index, value); }
        }

        /// <summary>
        /// 3D write access at the specified height, width, and channel.
        /// </summary>
        /// <param name="h">Height index.</param>
        /// <param name="w">Width index.</param>
        /// <param name="ch">Channel index.</param>
        /// <exception cref="IndexOutOfRangeException">
        /// Thrown when writing to a list target and <paramref name="h"/>, <paramref name="w"/> or
        /// <paramref name="ch"/> is outside the shape given to SetTarget.
        /// </exception>
        public float this[int h, int w, int ch]
        {
            set
            {
                if (m_Data != null)
                {
                    if (h < 0 || h >= m_TensorShape.height)
                    {
                        throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height - 1}]");
                    }
                    if (w < 0 || w >= m_TensorShape.width)
                    {
                        throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width - 1}]");
                    }
                    if (ch < 0 || ch >= m_TensorShape.channels)
                    {
                        throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels - 1}]");
                    }

                    var index = m_TensorShape.Index(m_Batch, h, w, ch + m_Offset);
                    m_Data[index] = value;
                }
                else
                {
                    m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value;
                }
            }
        }

        /// <summary>
        /// Write the list of floats.
        /// </summary>
        /// <param name="data">The actual list of floats to write.</param>
        /// <param name="writeOffset">Optional write offset to start writing from.</param>
        public void AddList(IList<float> data, int writeOffset = 0)
        {
            for (var index = 0; index < data.Count; index++)
            {
                WriteFlat(writeOffset + index, data[index]);
            }
        }

        /// <summary>
        /// Write the Vector3 components.
        /// </summary>
        /// <param name="vec">The Vector3 to be written.</param>
        /// <param name="writeOffset">Optional write offset.</param>
        public void Add(Vector3 vec, int writeOffset = 0)
        {
            WriteFlat(writeOffset + 0, vec.x);
            WriteFlat(writeOffset + 1, vec.y);
            WriteFlat(writeOffset + 2, vec.z);
        }

        /// <summary>
        /// Write the Vector4 components.
        /// </summary>
        /// <param name="vec">The Vector4 to be written.</param>
        /// <param name="writeOffset">Optional write offset.</param>
        public void Add(Vector4 vec, int writeOffset = 0)
        {
            WriteFlat(writeOffset + 0, vec.x);
            WriteFlat(writeOffset + 1, vec.y);
            WriteFlat(writeOffset + 2, vec.z);
            WriteFlat(writeOffset + 3, vec.w);
        }

        /// <summary>
        /// Write the Quaternion components.
        /// </summary>
        /// <param name="quat">The Quaternion to be written.</param>
        /// <param name="writeOffset">Optional write offset.</param>
        public void Add(Quaternion quat, int writeOffset = 0)
        {
            WriteFlat(writeOffset + 0, quat.x);
            WriteFlat(writeOffset + 1, quat.y);
            WriteFlat(writeOffset + 2, quat.z);
            WriteFlat(writeOffset + 3, quat.w);
        }
    }

    /// <summary>
    /// Provides extension methods for the ObservationWriter.
    /// </summary>
    public static class ObservationWriterExtension
    {
        /// <summary>
        /// Writes a Texture2D into a ObservationWriter.
        /// </summary>
        /// <param name="obsWriter">Writer to fill with Texture data.</param>
        /// <param name="texture">The texture to be put into the tensor.</param>
        /// <param name="grayScale">
        /// If set to true the textures will be converted to grayscale before
        /// being stored in the tensor.
        /// </param>
        /// <returns>The number of floats written.</returns>
        public static int WriteTexture(
            this ObservationWriter obsWriter,
            Texture2D texture,
            bool grayScale)
        {
            var width = texture.width;
            var height = texture.height;

            var texturePixels = texture.GetPixels32();

            // During training, we convert from Texture to PNG before sending to the trainer, which has the
            // effect of flipping the image. We need another flip here at inference time to match this.
            for (var h = height - 1; h >= 0; h--)
            {
                // Output row h reads from pixel row (height - h - 1); hoisted out of the inner loop.
                var rowStart = (height - h - 1) * width;
                for (var w = 0; w < width; w++)
                {
                    var currentPixel = texturePixels[rowStart + w];
                    if (grayScale)
                    {
                        obsWriter[h, w, 0] =
                            (currentPixel.r + currentPixel.g + currentPixel.b) / 3f / 255.0f;
                    }
                    else
                    {
                        // For Color32, the r, g and b values are between 0 and 255.
                        obsWriter[h, w, 0] = currentPixel.r / 255.0f;
                        obsWriter[h, w, 1] = currentPixel.g / 255.0f;
                        obsWriter[h, w, 2] = currentPixel.b / 255.0f;
                    }
                }
            }

            return height * width * (grayScale ? 1 : 3);
        }
    }
}