using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using UnityEngine;
namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// A sensor implementation for vector observations.
    /// </summary>
    /// <remarks>
    /// Values are buffered through the <c>AddObservation</c> overloads and
    /// <see cref="AddOneHotObservation"/>, then emitted by <see cref="Write"/>.
    /// The buffer is emptied by <see cref="Update"/> and <see cref="Reset"/>.
    /// </remarks>
    public class VectorSensor : ISensor, IBuiltInSensor
    {
        // TODO use float[] instead
        // TODO allow setting float[]
        private readonly List<float> m_Observations;
        private readonly ObservationSpec m_ObservationSpec;
        private readonly string m_Name;

        /// <summary>
        /// Initializes the sensor.
        /// </summary>
        /// <param name="observationSize">Number of vector observations.</param>
        /// <param name="name">
        /// Name of the sensor. When <c>null</c>, a name of the form
        /// <c>VectorSensor_size{observationSize}</c> is generated.
        /// </param>
        public VectorSensor(int observationSize, string name = null)
        {
            // Pre-size the buffer to the expected number of values.
            m_Observations = new List<float>(observationSize);
            m_Name = name ?? $"VectorSensor_size{observationSize}";
            m_ObservationSpec = ObservationSpec.FromShape(observationSize);
        }

        /// <summary>
        /// Writes the buffered observations to <paramref name="writer"/>. Before writing, the
        /// buffer is truncated or zero-padded (with a warning) so that it holds exactly the
        /// observation size given at construction.
        /// </summary>
        /// <param name="writer">Destination the observations are added to.</param>
        /// <returns>The number of values written, i.e. the configured observation size.</returns>
        /// <remarks>
        /// Truncation and padding modify the internal buffer in place, so
        /// <see cref="GetObservations"/> reflects the adjusted values afterwards.
        /// </remarks>
        public int Write(ObservationWriter writer)
        {
            var expectedObservations = m_ObservationSpec.Shape[0];
            var actualObservations = m_Observations.Count;
            if (actualObservations > expectedObservations)
            {
                // Too many observations, truncate
                Debug.LogWarningFormat(
                    "More observations ({0}) made than vector observation size ({1}). The observations will be truncated.",
                    actualObservations, expectedObservations
                );
                m_Observations.RemoveRange(expectedObservations, actualObservations - expectedObservations);
            }
            else if (actualObservations < expectedObservations)
            {
                // Not enough observations; pad with zeros.
                Debug.LogWarningFormat(
                    "Fewer observations ({0}) made than vector observation size ({1}). The observations will be padded.",
                    actualObservations, expectedObservations
                );
                for (var i = actualObservations; i < expectedObservations; i++)
                {
                    m_Observations.Add(0f);
                }
            }
            writer.AddList(m_Observations);
            return expectedObservations;
        }

        /// <summary>
        /// Returns a read-only view of the observations that have been added.
        /// </summary>
        /// <returns>A read-only view of the observations list.</returns>
        internal ReadOnlyCollection<float> GetObservations()
        {
            return m_Observations.AsReadOnly();
        }

        /// <summary>
        /// Clears all buffered observations.
        /// </summary>
        public void Update()
        {
            Clear();
        }

        /// <summary>
        /// Clears all buffered observations.
        /// </summary>
        public void Reset()
        {
            Clear();
        }

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_Name;
        }

        /// <summary>
        /// Vector observations are not compressed; always returns <c>null</c>.
        /// </summary>
        /// <returns><c>null</c>.</returns>
        public virtual byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <summary>
        /// Vector observations are not compressed.
        /// </summary>
        /// <returns><see cref="SensorCompressionType.None"/>.</returns>
        public virtual SensorCompressionType GetCompressionType()
        {
            return SensorCompressionType.None;
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.VectorSensor;
        }

        // Empties the observation buffer.
        private void Clear()
        {
            m_Observations.Clear();
        }

        // Appends a single value to the buffer. In DEBUG builds the value is first passed to
        // Utilities.DebugCheckNanAndInfinity (presumably reports NaN/Infinity - confirm there).
        private void AddFloatObs(float obs)
        {
#if DEBUG
            Utilities.DebugCheckNanAndInfinity(obs, nameof(obs), nameof(AddFloatObs));
#endif
            m_Observations.Add(obs);
        }

        // Compatibility methods with Agent observation. These should be removed eventually.

        /// <summary>
        /// Adds a float observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(float observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds an integer observation to the vector observations of the agent.
        /// The value is converted to <c>float</c>.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(int observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds a Vector3 observation to the vector observations of the agent,
        /// as three values in the order x, y, z.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector3 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
        }

        /// <summary>
        /// Adds a Vector2 observation to the vector observations of the agent,
        /// as two values in the order x, y.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector2 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
        }

        /// <summary>
        /// Adds a list or array of float observations to the vector observations of the agent,
        /// in list order.
        /// </summary>
        /// <param name="observation">Observation values.</param>
        /// <exception cref="ArgumentNullException"><paramref name="observation"/> is <c>null</c>.</exception>
        public void AddObservation(IList<float> observation)
        {
            if (observation == null)
            {
                throw new ArgumentNullException(nameof(observation));
            }
            for (var i = 0; i < observation.Count; i++)
            {
                AddFloatObs(observation[i]);
            }
        }

        /// <summary>
        /// Adds a quaternion observation to the vector observations of the agent,
        /// as four values in the order x, y, z, w.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Quaternion observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
            AddFloatObs(observation.w);
        }

        /// <summary>
        /// Adds a boolean observation to the vector observation of the agent,
        /// encoded as 1 for <c>true</c> and 0 for <c>false</c>.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(bool observation)
        {
            AddFloatObs(observation ? 1f : 0f);
        }

        /// <summary>
        /// Adds a one-hot encoding observation: <paramref name="range"/> values, where the
        /// value at index <paramref name="observation"/> is 1 and all others are 0.
        /// </summary>
        /// <remarks>
        /// If <paramref name="observation"/> is outside <c>[0, range)</c>, all
        /// <paramref name="range"/> values are 0. If <paramref name="range"/> is not
        /// positive, nothing is added.
        /// </remarks>
        /// <param name="observation">The index of this observation.</param>
        /// <param name="range">The number of values to emit (one past the max index).</param>
        public void AddOneHotObservation(int observation, int range)
        {
            for (var i = 0; i < range; i++)
            {
                AddFloatObs(i == observation ? 1.0f : 0.0f);
            }
        }
    }
}