using System.Collections.Generic;
using UnityEngine;

namespace MLAgents.Sensors
{
    /// <summary>
    /// A sensor implementation for vector observations.
    /// Values are buffered through the <c>AddObservation</c> overloads and emitted by
    /// <see cref="Write"/>, which truncates or zero-pads them to the configured size.
    /// </summary>
    public class VectorSensor : ISensor
    {
        // TODO use float[] instead
        // TODO allow setting float[]
        readonly List<float> m_Observations;
        readonly int[] m_Shape;
        readonly string m_Name;

        /// <summary>
        /// Initializes the sensor.
        /// </summary>
        /// <param name="observationSize">Number of vector observations.</param>
        /// <param name="name">
        /// Name of the sensor. When null, defaults to "VectorSensor_size{observationSize}".
        /// </param>
        public VectorSensor(int observationSize, string name = null)
        {
            if (name == null)
            {
                name = $"VectorSensor_size{observationSize}";
            }

            m_Observations = new List<float>(observationSize);
            m_Name = name;
            m_Shape = new[] { observationSize };
        }

        /// <inheritdoc/>
        /// <remarks>
        /// If more values were added than the configured observation size, the extras are
        /// discarded. If fewer were added, the buffer is padded with zeros. A warning is
        /// logged in either case. The internal buffer is truncated or padded in place
        /// before being handed to <paramref name="adapter"/>.
        /// </remarks>
        public int Write(WriteAdapter adapter)
        {
            var expectedObservations = m_Shape[0];
            if (m_Observations.Count > expectedObservations)
            {
                // Too many observations, truncate
                Debug.LogWarningFormat(
                    "More observations ({0}) made than vector observation size ({1}). The observations will be truncated.",
                    m_Observations.Count, expectedObservations
                );
                m_Observations.RemoveRange(expectedObservations, m_Observations.Count - expectedObservations);
            }
            else if (m_Observations.Count < expectedObservations)
            {
                // Not enough observations; pad with zeros.
                Debug.LogWarningFormat(
                    "Fewer observations ({0}) made than vector observation size ({1}). The observations will be padded.",
                    m_Observations.Count, expectedObservations
                );
                for (int i = m_Observations.Count; i < expectedObservations; i++)
                {
                    m_Observations.Add(0);
                }
            }

            adapter.AddRange(m_Observations);
            return expectedObservations;
        }

        /// <inheritdoc/>
        public void Update()
        {
            // Start each step with an empty buffer.
            Clear();
        }

        /// <inheritdoc/>
        public void Reset()
        {
            Clear();
        }

        /// <inheritdoc/>
        public int[] GetObservationShape()
        {
            return m_Shape;
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_Name;
        }

        /// <inheritdoc/>
        public virtual byte[] GetCompressedObservation()
        {
            // Vector observations are not compressed.
            return null;
        }

        /// <inheritdoc/>
        public virtual SensorCompressionType GetCompressionType()
        {
            return SensorCompressionType.None;
        }

        /// <summary>
        /// Removes all buffered observation values.
        /// </summary>
        void Clear()
        {
            m_Observations.Clear();
        }

        /// <summary>
        /// Appends a single float to the observation buffer. In DEBUG builds the value is
        /// first passed to <c>Utilities.DebugCheckNanAndInfinity</c>.
        /// </summary>
        /// <param name="obs">Value to append.</param>
        void AddFloatObs(float obs)
        {
#if DEBUG
            Utilities.DebugCheckNanAndInfinity(obs, nameof(obs), nameof(AddFloatObs));
#endif
            m_Observations.Add(obs);
        }

        // Compatibility methods with Agent observation. These should be removed eventually.

        /// <summary>
        /// Adds a float observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(float observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds an integer observation to the vector observations of the agent.
        /// The value is converted to float.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(int observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds a Vector3 observation to the vector observations of the agent
        /// (three values: x, y, z).
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector3 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
        }

        /// <summary>
        /// Adds a Vector2 observation to the vector observations of the agent
        /// (two values: x, y).
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector2 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
        }

        /// <summary>
        /// Adds a collection of float observations to the vector observations of the agent,
        /// in enumeration order.
        /// </summary>
        /// <param name="observation">Observation values.</param>
        public void AddObservation(IEnumerable<float> observation)
        {
            foreach (var f in observation)
            {
                AddFloatObs(f);
            }
        }

        /// <summary>
        /// Adds a quaternion observation to the vector observations of the agent
        /// (four values: x, y, z, w).
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Quaternion observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
            AddFloatObs(observation.w);
        }

        /// <summary>
        /// Adds a boolean observation to the vector observations of the agent
        /// (1 for true, 0 for false).
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(bool observation)
        {
            AddFloatObs(observation ? 1f : 0f);
        }

        /// <summary>
        /// Adds a one-hot encoding observation: <paramref name="range"/> values, all 0 except
        /// a 1 at index <paramref name="observation"/>.
        /// </summary>
        /// <param name="observation">
        /// The index to set. If it is outside [0, <paramref name="range"/>), all emitted values are 0.
        /// </param>
        /// <param name="range">The number of one-hot slots (exclusive upper bound of the index).</param>
        public void AddOneHotObservation(int observation, int range)
        {
            for (var i = 0; i < range; i++)
            {
                AddFloatObs(i == observation ? 1.0f : 0.0f);
            }
        }
    }
}