Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。
 
 
 
 
 

216 行
6.7 KiB

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using UnityEngine;
namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// A sensor implementation for vector observations.
    /// Values are appended through the <c>AddObservation</c> overloads and
    /// <see cref="AddOneHotObservation"/>, and are emitted by <see cref="Write"/>, which always
    /// writes exactly the length declared by the sensor's observation spec (truncating or
    /// zero-padding the accumulated values when the count does not match).
    /// </summary>
    public class VectorSensor : ISensor, IBuiltInSensor
    {
        // TODO use float[] instead
        // TODO allow setting float[]

        // Values accumulated since the last Update()/Reset(). Write() may truncate or pad this
        // list in place so that it matches the spec length.
        readonly List<float> m_Observations;

        // Built once from ObservationSpec.Vector(observationSize); Write() reads the expected
        // observation count from Shape[0].
        readonly ObservationSpec m_ObservationSpec;

        readonly string m_Name;

        /// <summary>
        /// Initializes the sensor.
        /// </summary>
        /// <param name="observationSize">Number of vector observations. Also used as the initial
        /// capacity of the internal observation list.</param>
        /// <param name="name">Name of the sensor. When <c>null</c>, defaults to
        /// <c>"VectorSensor_size{observationSize}"</c>.</param>
        /// <exception cref="ArgumentOutOfRangeException">Thrown by the <see cref="List{T}"/>
        /// constructor when <paramref name="observationSize"/> is negative.</exception>
        public VectorSensor(int observationSize, string name = null)
        {
            m_Name = name ?? $"VectorSensor_size{observationSize}";
            m_Observations = new List<float>(observationSize);
            m_ObservationSpec = ObservationSpec.Vector(observationSize);
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Always writes exactly <c>Shape[0]</c> of the observation spec and returns that count.
        /// If more observations were added, the extras are removed from the internal list; if
        /// fewer, the internal list is padded with zeros. Either mismatch logs a warning. Because
        /// the internal list itself is modified, <see cref="GetObservations"/> reflects the
        /// truncated/padded values until the next <see cref="Update"/> or <see cref="Reset"/>.
        /// </remarks>
        public int Write(ObservationWriter writer)
        {
            var expectedObservations = m_ObservationSpec.Shape[0];
            if (m_Observations.Count > expectedObservations)
            {
                // Too many observations, truncate
                Debug.LogWarningFormat(
                    "More observations ({0}) made than vector observation size ({1}). The observations will be truncated.",
                    m_Observations.Count, expectedObservations
                );
                m_Observations.RemoveRange(expectedObservations, m_Observations.Count - expectedObservations);
            }
            else if (m_Observations.Count < expectedObservations)
            {
                // Not enough observations; pad with zeros.
                Debug.LogWarningFormat(
                    "Fewer observations ({0}) made than vector observation size ({1}). The observations will be padded.",
                    m_Observations.Count, expectedObservations
                );
                for (int i = m_Observations.Count; i < expectedObservations; i++)
                {
                    m_Observations.Add(0);
                }
            }
            writer.AddList(m_Observations);
            return expectedObservations;
        }

        /// <summary>
        /// Returns a read-only view of the observations that have been added since the last
        /// <see cref="Update"/> or <see cref="Reset"/>.
        /// </summary>
        /// <remarks>
        /// The view wraps the internal list rather than copying it, so later additions, clears,
        /// or the truncation/padding performed by <see cref="Write"/> are visible through it.
        /// </remarks>
        /// <returns>A read-only view of the observations list.</returns>
        internal ReadOnlyCollection<float> GetObservations()
        {
            return m_Observations.AsReadOnly();
        }

        /// <inheritdoc/>
        /// <remarks>Discards all accumulated observations.</remarks>
        public void Update()
        {
            Clear();
        }

        /// <inheritdoc/>
        /// <remarks>Discards all accumulated observations.</remarks>
        public void Reset()
        {
            Clear();
        }

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_Name;
        }

        /// <inheritdoc/>
        /// <remarks>This sensor does not compress its output; always returns <c>null</c>.</remarks>
        public virtual byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <inheritdoc/>
        /// <remarks>Always <c>SensorCompressionType.None</c>.</remarks>
        public virtual SensorCompressionType GetCompressionType()
        {
            return SensorCompressionType.None;
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.VectorSensor;
        }

        // Removes every accumulated observation; shared by Update() and Reset().
        void Clear()
        {
            m_Observations.Clear();
        }

        // Single funnel through which every AddObservation overload appends a value.
        void AddFloatObs(float obs)
        {
#if DEBUG
            // Debug-only validation hook; presumably flags NaN/Infinity values (see
            // Utilities.DebugCheckNanAndInfinity) — NOTE(review): confirm its exact behavior.
            Utilities.DebugCheckNanAndInfinity(obs, nameof(obs), nameof(AddFloatObs));
#endif
            m_Observations.Add(obs);
        }

        // Compatibility methods with Agent observation. These should be removed eventually.

        /// <summary>
        /// Adds a float observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(float observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds an integer observation to the vector observations of the agent.
        /// </summary>
        /// <remarks>
        /// The value is implicitly converted to <c>float</c>, which cannot represent every
        /// <c>int</c> exactly (magnitudes above 2^24 may lose precision).
        /// </remarks>
        /// <param name="observation">Observation.</param>
        public void AddObservation(int observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds a Vector3 observation to the vector observations of the agent.
        /// Appends three values in the order x, y, z.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector3 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
        }

        /// <summary>
        /// Adds a Vector2 observation to the vector observations of the agent.
        /// Appends two values in the order x, y.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector2 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
        }

        /// <summary>
        /// Adds a list or array of float observations to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation values, appended in index order.</param>
        /// <exception cref="ArgumentNullException"><paramref name="observation"/> is <c>null</c>.</exception>
        public void AddObservation(IList<float> observation)
        {
            if (observation == null)
            {
                throw new ArgumentNullException(nameof(observation));
            }
            for (var i = 0; i < observation.Count; i++)
            {
                AddFloatObs(observation[i]);
            }
        }

        /// <summary>
        /// Adds a quaternion observation to the vector observations of the agent.
        /// Appends four values in the order x, y, z, w.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Quaternion observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
            AddFloatObs(observation.w);
        }

        /// <summary>
        /// Adds a boolean observation to the vector observation of the agent,
        /// encoded as 1 for <c>true</c> and 0 for <c>false</c>.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(bool observation)
        {
            AddFloatObs(observation ? 1f : 0f);
        }

        /// <summary>
        /// Adds a one-hot encoding observation.
        /// </summary>
        /// <remarks>
        /// Appends exactly <paramref name="range"/> values: 1 at position
        /// <paramref name="observation"/> and 0 everywhere else. If
        /// <paramref name="observation"/> is outside <c>[0, range)</c>, all appended values are 0.
        /// If <paramref name="range"/> is zero or negative, nothing is appended.
        /// </remarks>
        /// <param name="observation">The index of the hot (1) entry.</param>
        /// <param name="range">The number of entries in the encoding; valid indices are
        /// <c>0</c> to <c>range - 1</c>.</param>
        public void AddOneHotObservation(int observation, int range)
        {
            for (var i = 0; i < range; i++)
            {
                AddFloatObs(i == observation ? 1.0f : 0.0f);
            }
        }
    }
}