Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。
 
 
 
 
 

348 行
12 KiB

using System;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Profiling;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// Selects how the GridSensor processes the colliders detected in a cell.
/// </summary>
public enum ProcessCollidersMethod
{
    /// <summary>
    /// Record data from every collider detected in the cell.
    /// </summary>
    ProcessAllColliders = 0,

    /// <summary>
    /// Record data only from the detected collider closest to the agent.
    /// </summary>
    ProcessClosestColliders = 1,
}
/// <summary>
/// Grid-based sensor.
/// </summary>
/// <remarks>
/// Observations are stored in a flat buffer of <c>gridSize.x * gridSize.z</c> cells, each holding
/// <see cref="GetCellObservationSize"/> floats. The buffer is cleared at the start of every
/// <see cref="Update"/>. Subclasses customize what is recorded for a detected object by overriding
/// <see cref="GetObjectData"/>, together with <see cref="GetCellObservationSize"/>,
/// <see cref="IsDataNormalized"/> and <see cref="GetProcessCollidersMethod"/> as needed.
/// </remarks>
public class GridSensorBase : ISensor, IBuiltInSensor, IDisposable
{
    string m_Name;
    // Stored from the constructor; not read anywhere in this class.
    Vector3 m_CellScale;
    Vector3Int m_GridSize;
    string[] m_DetectableTags;
    SensorCompressionType m_CompressionType;
    ObservationSpec m_ObservationSpec;
    internal BoxOverlapChecker m_BoxOverlapChecker;

    // Buffers
    // Flat observation data: m_CellObservationSize floats per cell, m_NumCells cells.
    float[] m_PerceptionBuffer;
    // Scratch pixels used when packing up to three observation channels into one PNG image.
    Color[] m_PerceptionColors;
    Texture2D m_PerceptionTexture;
    // Scratch buffer of one cell's observation, handed to GetObjectData.
    float[] m_CellDataBuffer;

    // Utility Constants Calculated on Init
    int m_NumCells;
    int m_CellObservationSize;
    // Never assigned or read in this class.
    Vector3 m_CellCenterOffset;

    /// <summary>
    /// Create a GridSensorBase with the specified configuration.
    /// </summary>
    /// <remarks>
    /// The constructor calls the virtual <see cref="GetCellObservationSize"/> and (through the
    /// <see cref="CompressionType"/> setter) <see cref="IsDataNormalized"/>. Overrides therefore run
    /// before a derived class's constructor body and must not rely on state that body sets.
    /// </remarks>
    /// <param name="name">The sensor name</param>
    /// <param name="cellScale">The scale of each cell in the grid</param>
    /// <param name="gridSize">Number of cells on each side of the grid. Only <c>y == 1</c> (2D grids) is supported.</param>
    /// <param name="detectableTags">Tags to be detected by the sensor</param>
    /// <param name="compression">Compression type</param>
    /// <exception cref="UnityAgentsException">Thrown when <paramref name="gridSize"/> has a <c>y</c> component other than 1.</exception>
    public GridSensorBase(
        string name,
        Vector3 cellScale,
        Vector3Int gridSize,
        string[] detectableTags,
        SensorCompressionType compression
    )
    {
        // Reject unsupported grids before anything else, so a bad configuration fails fast
        // without first emitting an unrelated compression warning.
        if (gridSize.y != 1)
        {
            throw new UnityAgentsException("GridSensor only supports 2D grids.");
        }

        m_Name = name;
        m_CellScale = cellScale;
        m_GridSize = gridSize;
        m_DetectableTags = detectableTags;
        CompressionType = compression;

        m_NumCells = m_GridSize.x * m_GridSize.z;
        m_CellObservationSize = GetCellObservationSize();
        m_ObservationSpec = ObservationSpec.Visual(m_GridSize.x, m_GridSize.z, m_CellObservationSize);
        m_PerceptionTexture = new Texture2D(m_GridSize.x, m_GridSize.z, TextureFormat.RGB24, false);
        ResetPerceptionBuffer();
    }

    /// <summary>
    /// The compression type used by the sensor.
    /// </summary>
    /// <remarks>
    /// Requesting <see cref="SensorCompressionType.PNG"/> while <see cref="IsDataNormalized"/> returns false
    /// logs a warning and leaves the current compression type unchanged.
    /// </remarks>
    public SensorCompressionType CompressionType
    {
        get { return m_CompressionType; }
        set
        {
            if (!IsDataNormalized() && value == SensorCompressionType.PNG)
            {
                Debug.LogWarning($"Compression type {value} is only supported with normalized data. " +
                    "The sensor will not compress the data.");
                return;
            }
            m_CompressionType = value;
        }
    }

    /// <summary>
    /// The raw per-cell observation buffer (m_CellObservationSize floats per cell).
    /// </summary>
    internal float[] PerceptionBuffer
    {
        get { return m_PerceptionBuffer; }
    }

    /// <summary>
    /// The tags which the sensor detects.
    /// </summary>
    protected string[] DetectableTags
    {
        get { return m_DetectableTags; }
    }

    /// <inheritdoc/>
    public void Reset() { }

    /// <summary>
    /// Clears the perception buffer before loading in new data.
    /// The buffers are allocated on the first call and zeroed on every later call.
    /// </summary>
    public void ResetPerceptionBuffer()
    {
        if (m_PerceptionBuffer != null)
        {
            Array.Clear(m_PerceptionBuffer, 0, m_PerceptionBuffer.Length);
            Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
        }
        else
        {
            m_PerceptionBuffer = new float[m_CellObservationSize * m_NumCells];
            m_CellDataBuffer = new float[m_CellObservationSize];
            m_PerceptionColors = new Color[m_NumCells];
        }
    }

    /// <inheritdoc/>
    public string GetName()
    {
        return m_Name;
    }

    /// <inheritdoc/>
    public CompressionSpec GetCompressionSpec()
    {
        return new CompressionSpec(CompressionType);
    }

    /// <inheritdoc/>
    public BuiltInSensorType GetBuiltInSensorType()
    {
        return BuiltInSensorType.GridSensor;
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Each PNG image carries up to three observation channels (R, G, B). The encoded images are
    /// concatenated in channel order.
    /// </remarks>
    public byte[] GetCompressedObservation()
    {
        using (TimerStack.Instance.Scoped("GridSensor.GetCompressedObservation"))
        {
            var allBytes = new List<byte>();
            // ceil(m_CellObservationSize / 3) images are needed to hold every channel.
            var numImages = (m_CellObservationSize + 2) / 3;
            for (int i = 0; i < numImages; i++)
            {
                var channelIndex = 3 * i;
                GridValuesToTexture(channelIndex, Math.Min(3, m_CellObservationSize - channelIndex));
                allBytes.AddRange(m_PerceptionTexture.EncodeToPNG());
            }
            return allBytes.ToArray();
        }
    }

    /// <summary>
    /// Convert observation values to texture for PNG compression.
    /// </summary>
    /// <param name="channelIndex">First observation channel to pack into the texture.</param>
    /// <param name="numChannelsToAdd">Number of channels (1 to 3) to pack, starting at <paramref name="channelIndex"/>.</param>
    void GridValuesToTexture(int channelIndex, int numChannelsToAdd)
    {
        for (int i = 0; i < m_NumCells; i++)
        {
            var cellOffset = i * m_CellObservationSize + channelIndex;
            for (int j = 0; j < 3; j++)
            {
                // Color channels past numChannelsToAdd are padding. Zero them so the last image of a
                // multi-image observation does not carry stale values left over from the previous image.
                m_PerceptionColors[i][j] = j < numChannelsToAdd ? m_PerceptionBuffer[cellOffset + j] : 0f;
            }
        }
        m_PerceptionTexture.SetPixels(m_PerceptionColors);
    }

    /// <summary>
    /// Get the observation values of the detected game object.
    /// Default is to record the detected tag index.
    ///
    /// This method can be overridden to encode the observation differently or get custom data from the object.
    /// When overriding this method, <seealso cref="GetCellObservationSize"/> and <seealso cref="IsDataNormalized"/>
    /// might also need to change accordingly.
    /// </summary>
    /// <param name="detectedObject">The game object that was detected within a certain cell</param>
    /// <param name="tagIndex">The index of the detectedObject's tag in the DetectableObjects list</param>
    /// <param name="dataBuffer">The buffer to write the observation values.
    /// The buffer size is configured by <seealso cref="GetCellObservationSize"/>.
    /// </param>
    /// <example>
    /// Here is an example of overriding GetObjectData to get the velocity of a potential Rigidbody:
    /// <code>
    /// protected override void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
    /// {
    /// if (tagIndex == Array.IndexOf(DetectableTags, "RigidBodyObject"))
    /// {
    /// Rigidbody rigidbody = detectedObject.GetComponent&lt;Rigidbody&gt;();
    /// dataBuffer[0] = rigidbody.velocity.x;
    /// dataBuffer[1] = rigidbody.velocity.y;
    /// dataBuffer[2] = rigidbody.velocity.z;
    /// }
    /// }
    /// </code>
    /// </example>
    protected virtual void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
    {
        // Offset by one so that 0 remains "nothing detected" in a cleared cell.
        dataBuffer[0] = tagIndex + 1;
    }

    /// <summary>
    /// Get the observation size for each cell. This will be the size of dataBuffer for <seealso cref="GetObjectData"/>.
    /// If overriding <seealso cref="GetObjectData"/>, override this method as well to the custom observation size.
    /// </summary>
    /// <returns>The observation size of each cell.</returns>
    protected virtual int GetCellObservationSize()
    {
        return 1;
    }

    /// <summary>
    /// Whether the data is normalized within [0, 1]. The sensor can only use PNG compression if the data is normalized.
    /// If overriding <seealso cref="GetObjectData"/>, override this method as well according to the custom observation values.
    /// </summary>
    /// <returns>Bool value indicating whether data is normalized.</returns>
    protected virtual bool IsDataNormalized()
    {
        return false;
    }

    /// <summary>
    /// How to handle multiple colliders detected in the same cell. Defaults to
    /// <see cref="ProcessCollidersMethod.ProcessClosestColliders"/>, which only uses the one closest to the agent.
    /// If overriding <seealso cref="GetObjectData"/>, consider overriding this method when needed.
    /// </summary>
    /// <returns>The <see cref="ProcessCollidersMethod"/> used for cells with several detected colliders.</returns>
    protected internal virtual ProcessCollidersMethod GetProcessCollidersMethod()
    {
        return ProcessCollidersMethod.ProcessClosestColliders;
    }

    /// <summary>
    /// If using PNG compression, check if the values are normalized.
    /// </summary>
    /// <exception cref="UnityAgentsException">Thrown when PNG compression is active and a value is outside [0, 1] or NaN.</exception>
    void ValidateValues(float[] dataValues, GameObject detectedObject)
    {
        if (m_CompressionType != SensorCompressionType.PNG)
        {
            return;
        }
        foreach (var value in dataValues)
        {
            // Negated range check so NaN, which fails every comparison, is rejected as well.
            if (!(value >= 0f && value <= 1f))
            {
                throw new UnityAgentsException($"When using compression type {m_CompressionType} the data value has to be normalized between 0-1. " +
                    $"Received value[{value}] for {detectedObject.name}");
            }
        }
    }

    /// <summary>
    /// Collect data from the detected object if a detectable tag is matched.
    /// Only the first matching tag (in <see cref="DetectableTags"/> order) is used.
    /// </summary>
    /// <param name="detectedObject">The detected object; a null reference is ignored.</param>
    /// <param name="cellIndex">Index of the cell in the flat perception buffer.</param>
    internal void ProcessDetectedObject(GameObject detectedObject, int cellIndex)
    {
        // The null check does not depend on the tag being tested, so do it once up front.
        if (ReferenceEquals(detectedObject, null))
        {
            return;
        }

        Profiler.BeginSample("GridSensor.ProcessDetectedObject");
        try
        {
            var cellOffset = cellIndex * m_CellObservationSize;
            for (var i = 0; i < m_DetectableTags.Length; i++)
            {
                if (!detectedObject.CompareTag(m_DetectableTags[i]))
                {
                    continue;
                }

                if (GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders)
                {
                    // Seed with what is already recorded for this cell so GetObjectData can accumulate.
                    Array.Copy(m_PerceptionBuffer, cellOffset, m_CellDataBuffer, 0, m_CellObservationSize);
                }
                else
                {
                    // Previous contents of the cell are discarded; only this object's data is kept.
                    Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
                }
                GetObjectData(detectedObject, i, m_CellDataBuffer);
                ValidateValues(m_CellDataBuffer, detectedObject);
                Array.Copy(m_CellDataBuffer, 0, m_PerceptionBuffer, cellOffset, m_CellObservationSize);
                break;
            }
        }
        finally
        {
            // Keep profiler samples balanced even when GetObjectData or ValidateValues throws.
            Profiler.EndSample();
        }
    }

    /// <inheritdoc/>
    public void Update()
    {
        ResetPerceptionBuffer();
        using (TimerStack.Instance.Scoped("GridSensor.Update"))
        {
            if (m_BoxOverlapChecker != null)
            {
                m_BoxOverlapChecker.Update();
            }
        }
    }

    /// <inheritdoc/>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <inheritdoc/>
    /// <returns>The number of floats written (cells times observation size per cell).</returns>
    public int Write(ObservationWriter writer)
    {
        using (TimerStack.Instance.Scoped("GridSensor.Write"))
        {
            int index = 0;
            // NOTE(review): rows are emitted from the highest z index down to 0 while the buffer is read
            // sequentially, presumably so the grid's +z side maps to the first observation row. Confirm.
            for (var h = m_GridSize.z - 1; h >= 0; h--)
            {
                for (var w = 0; w < m_GridSize.x; w++)
                {
                    for (var d = 0; d < m_CellObservationSize; d++)
                    {
                        writer[h, w, d] = m_PerceptionBuffer[index];
                        index++;
                    }
                }
            }
            return index;
        }
    }

    /// <summary>
    /// Clean up the internal objects.
    /// </summary>
    public void Dispose()
    {
        if (!ReferenceEquals(null, m_PerceptionTexture))
        {
            Utilities.DestroyTexture(m_PerceptionTexture);
            m_PerceptionTexture = null;
        }
    }
}
}