Unity 机器学习代理工具包（ML-Agents）是一个开源项目，它使游戏和模拟能够用作训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。
 
 
 
 
 

172 行
5.2 KiB

using System.Collections.Generic;
using UnityEngine;
namespace MLAgents.Sensor
{
    /// <summary>
    /// A sensor that collects a flat list of float observations with a fixed declared size.
    /// Observations are appended through the <c>AddObservation</c> overloads and emitted by
    /// <see cref="Write"/>, which truncates or zero-pads them to the declared size.
    /// </summary>
    public class VectorSensor : ISensor
    {
        // TODO use float[] instead
        // TODO allow setting float[]
        readonly List<float> m_Observations;
        readonly int[] m_Shape;
        readonly string m_Name;

        /// <summary>
        /// Creates a vector sensor that writes exactly <paramref name="observationSize"/> floats.
        /// </summary>
        /// <param name="observationSize">Declared number of float observations. Must be non-negative.</param>
        /// <param name="name">
        /// Sensor name. When null, defaults to <c>"VectorSensor_size{observationSize}"</c>.
        /// </param>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// Thrown when <paramref name="observationSize"/> is negative.
        /// </exception>
        public VectorSensor(int observationSize, string name = null)
        {
            // Same exception type the List<float> capacity check would raise, but naming the
            // caller's parameter instead of the internal "capacity".
            if (observationSize < 0)
            {
                throw new System.ArgumentOutOfRangeException(
                    nameof(observationSize), observationSize, "Observation size must be non-negative.");
            }

            m_Observations = new List<float>(observationSize);
            m_Name = name ?? $"VectorSensor_size{observationSize}";
            m_Shape = new[] { observationSize };
        }

        /// <summary>
        /// Passes the accumulated observations to <paramref name="adapter"/>.
        /// If more observations were added than the declared size, the extras are dropped;
        /// if fewer, the list is padded with zeros. A warning is logged in either case.
        /// The internal buffer is normalized in place before being written.
        /// </summary>
        /// <param name="adapter">Destination that receives the observations via <c>AddRange</c>.</param>
        /// <returns>The declared observation size.</returns>
        public int Write(WriteAdapter adapter)
        {
            var expectedObservations = m_Shape[0];
            var actualObservations = m_Observations.Count;
            if (actualObservations > expectedObservations)
            {
                // Too many observations, truncate
                Debug.LogWarningFormat(
                    "More observations ({0}) made than vector observation size ({1}). The observations will be truncated.",
                    actualObservations, expectedObservations
                );
                m_Observations.RemoveRange(expectedObservations, actualObservations - expectedObservations);
            }
            else if (actualObservations < expectedObservations)
            {
                // Not enough observations; pad with zeros.
                Debug.LogWarningFormat(
                    "Fewer observations ({0}) made than vector observation size ({1}). The observations will be padded.",
                    actualObservations, expectedObservations
                );
                for (var i = actualObservations; i < expectedObservations; i++)
                {
                    m_Observations.Add(0f);
                }
            }

            adapter.AddRange(m_Observations);
            return expectedObservations;
        }

        /// <summary>
        /// Clears all accumulated observations so a fresh set can be added.
        /// </summary>
        public void Update()
        {
            Clear();
        }

        /// <summary>
        /// Returns the observation shape, a single-element array holding the declared size.
        /// </summary>
        /// <remarks>
        /// The internal array is returned directly; callers should not modify it.
        /// </remarks>
        public int[] GetFloatObservationShape()
        {
            return m_Shape;
        }

        /// <summary>
        /// Returns the sensor name given at construction (or the generated default).
        /// </summary>
        public string GetName()
        {
            return m_Name;
        }

        /// <summary>
        /// Vector observations are not compressed; always returns null.
        /// </summary>
        public virtual byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <summary>
        /// Vector observations are not compressed; always returns <c>SensorCompressionType.None</c>.
        /// </summary>
        public virtual SensorCompressionType GetCompressionType()
        {
            return SensorCompressionType.None;
        }

        // Removes every accumulated observation.
        private void Clear()
        {
            m_Observations.Clear();
        }

        // Single append point used by all AddObservation overloads.
        private void AddFloatObs(float obs)
        {
            m_Observations.Add(obs);
        }

        // Compatibility methods with Agent observation. These should be removed eventually.

        /// <summary>
        /// Adds a float observation to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(float observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds an integer observation to the vector observations of the agent.
        /// The value is converted to float, so integers beyond 2^24 in magnitude may lose precision.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(int observation)
        {
            AddFloatObs(observation);
        }

        /// <summary>
        /// Adds a Vector3 observation (x, y, z — three floats) to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector3 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
        }

        /// <summary>
        /// Adds a Vector2 observation (x, y — two floats) to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Vector2 observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
        }

        /// <summary>
        /// Adds a collection of float observations, in enumeration order, to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation values.</param>
        /// <exception cref="System.ArgumentNullException">Thrown when <paramref name="observation"/> is null.</exception>
        public void AddObservation(IEnumerable<float> observation)
        {
            if (observation == null)
            {
                throw new System.ArgumentNullException(nameof(observation));
            }

            foreach (var f in observation)
            {
                AddFloatObs(f);
            }
        }

        /// <summary>
        /// Adds a quaternion observation (x, y, z, w — four floats) to the vector observations of the agent.
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(Quaternion observation)
        {
            AddFloatObs(observation.x);
            AddFloatObs(observation.y);
            AddFloatObs(observation.z);
            AddFloatObs(observation.w);
        }

        /// <summary>
        /// Adds a boolean observation to the vector observations of the agent, encoded as 1 (true) or 0 (false).
        /// </summary>
        /// <param name="observation">Observation.</param>
        public void AddObservation(bool observation)
        {
            AddFloatObs(observation ? 1f : 0f);
        }

        /// <summary>
        /// Adds a one-hot encoding of <paramref name="observation"/>: <paramref name="range"/> floats,
        /// 1 at index <paramref name="observation"/> and 0 everywhere else.
        /// </summary>
        /// <param name="observation">Index of the hot element.</param>
        /// <param name="range">Number of floats to add.</param>
        /// <remarks>
        /// If <paramref name="observation"/> is outside [0, <paramref name="range"/>), all added values are 0.
        /// If <paramref name="range"/> is zero or negative, nothing is added.
        /// </remarks>
        public void AddOneHotObservation(int observation, int range)
        {
            for (var i = 0; i < range; i++)
            {
                AddFloatObs(i == observation ? 1.0f : 0.0f);
            }
        }
    }
}