using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using UnityEngine;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// A sensor implementation for vector observations.
    /// </summary>
    /// <remarks>
    /// Observations are buffered with the <c>AddObservation</c> overloads and written out by
    /// <see cref="Write"/>. The buffer is cleared by <see cref="Update"/> and <see cref="Reset"/>.
    /// </remarks>
    public class VectorSensor : ISensor, IBuiltInSensor
    {
        // TODO use float[] instead
        // TODO allow setting float[]
        private readonly List<float> m_Observations;
        private readonly ObservationSpec m_ObservationSpec;
        private readonly string m_Name;

        /// <summary>
        /// Initializes the sensor.
        /// </summary>
        /// <param name="observationSize">Number of vector observations.</param>
        /// <param name="name">
        /// Name of the sensor. When null, a name of the form <c>VectorSensor_size{observationSize}</c>
        /// is generated.
        /// </param>
        public VectorSensor(int observationSize, string name = null)
        {
            if (name == null)
            {
                name = $"VectorSensor_size{observationSize}";
            }

            // Pre-size the buffer so a full set of observations does not reallocate.
            m_Observations = new List<float>(observationSize);
            m_Name = name;
            m_ObservationSpec = ObservationSpec.FromShape(observationSize);
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Before writing, the buffered observations are resized in place to exactly the size declared
        /// by the observation spec. Extra values are removed and missing values are filled with zeros.
        /// A warning is logged in either case.
        /// </remarks>
        /// <returns>The number of values written, i.e. the first dimension of the observation spec.</returns>
        public int Write(ObservationWriter writer)
        {
            var expectedObservations = m_ObservationSpec.Shape[0];
            if (m_Observations.Count > expectedObservations)
            {
                // Too many observations, truncate
                Debug.LogWarningFormat(
                    "More observations ({0}) made than vector observation size ({1}). The observations will be truncated.",
                    m_Observations.Count, expectedObservations
                );
                m_Observations.RemoveRange(expectedObservations, m_Observations.Count - expectedObservations);
            }
            else if (m_Observations.Count < expectedObservations)
            {
                // Not enough observations; pad with zeros.
                Debug.LogWarningFormat(
                    "Fewer observations ({0}) made than vector observation size ({1}). The observations will be padded.",
                    m_Observations.Count, expectedObservations
                );
                for (int i = m_Observations.Count; i < expectedObservations; i++)
                {
                    m_Observations.Add(0);
                }
            }

            writer.AddList(m_Observations);
            return expectedObservations;
        }

        /// <summary>
        /// Returns a read-only view of the observations that were added.
        /// </summary>
        /// <remarks>
        /// The returned collection wraps the internal list rather than copying it. It therefore reflects
        /// any later additions or clears.
        /// </remarks>
        /// <returns>A read-only view of the observations list.</returns>
        internal ReadOnlyCollection<float> GetObservations()
        {
            return m_Observations.AsReadOnly();
        }

        /// <inheritdoc/>
        /// <remarks>Clears the buffered observations.</remarks>
        public void Update()
        {
            Clear();
        }

        /// <inheritdoc/>
        /// <remarks>Clears the buffered observations.</remarks>
        public void Reset()
        {
            Clear();
        }

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_Name;
        }

        /// <inheritdoc/>
        /// <remarks>This sensor does not compress its output, so this always returns null.</remarks>
        public virtual byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <inheritdoc/>
        public virtual SensorCompressionType GetCompressionType()
        {
            return SensorCompressionType.None;
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.VectorSensor;
        }

        // Drops every buffered observation; the list's capacity is kept.
        private void Clear()
        {
            m_Observations.Clear();
        }

        // Appends one value to the buffer. In DEBUG builds the value is first passed to
        // Utilities.DebugCheckNanAndInfinity (presumably to flag NaN/Infinity inputs — see that helper).
        private void AddFloatObs(float obs)
        {
#if DEBUG
            Utilities.DebugCheckNanAndInfinity(obs, nameof(obs), nameof(AddFloatObs));
#endif
            m_Observations.Add(obs);
        }

        // Compatibility methods with Agent observation. These should be removed eventually.

        /// <summary>
        /// Adds a float observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(float observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds an integer observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation; converted to <c>float</c>.</param>
        public void AddObservation(int observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds a Vector3 observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation; adds three values in x, y, z order.</param>
        public void AddObservation(Vector3 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
        }

        /// <summary>
        /// Adds a Vector2 observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation; adds two values in x, y order.</param>
        public void AddObservation(Vector2 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
        }

        /// <summary>
        /// Adds a list or array of float observations to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation; every element is added in index order.</param>
        public void AddObservation(IList<float> observation)
        {
            for (var i = 0; i < observation.Count; i++)
            {
                AddFloatObs(observation[i]);
            }
        }

        /// <summary>
        /// Adds a quaternion observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation; adds four values in x, y, z, w order.</param>
        public void AddObservation(Quaternion observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
            AddFloatObs(observation.w);
        }

        /// <summary>
        /// Adds a boolean observation to the vector observation of the agent.
        /// </summary>
        /// <param name="observation">Observation; encoded as 1 for true and 0 for false.</param>
        public void AddObservation(bool observation)
        {
            AddFloatObs(observation ? 1f : 0f);
        }

        /// <summary>
        /// Adds a one-hot encoding observation.
        /// </summary>
        /// <remarks>
        /// Exactly <paramref name="range"/> values are added. If <paramref name="observation"/> is outside
        /// <c>[0, range)</c>, all of them are zero.
        /// </remarks>
        /// <param name="observation">The index of the slot to set to 1.</param>
        /// <param name="range">The number of one-hot slots (exclusive upper bound on the index).</param>
        public void AddOneHotObservation(int observation, int range)
        {
            for (var i = 0; i < range; i++)
            {
                AddFloatObs(i == observation ? 1.0f : 0.0f);
            }
        }
    }
}