using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Grid-based sensor that counts the number of detectable objects in each cell,
/// keeping one counter per detectable tag.
/// </summary>
public class CountingGridSensor : GridSensorBase
{
    /// <summary>
    /// Create a CountingGridSensor with the specified configuration.
    /// </summary>
    /// <remarks>
    /// The requested compression is forwarded to the base constructor, after which
    /// <c>CompressionType</c> is explicitly set to <see cref="SensorCompressionType.None"/>.
    /// NOTE(review): presumably because raw counts are not normalized
    /// (see <see cref="IsDataNormalized"/>) - confirm against GridSensorBase.
    /// </remarks>
    /// <param name="name">The sensor name</param>
    /// <param name="cellScale">The scale of each cell in the grid</param>
    /// <param name="gridSize">Number of cells on each side of the grid</param>
    /// <param name="detectableTags">Tags to be detected by the sensor</param>
    /// <param name="compression">Compression type</param>
    public CountingGridSensor(
        string name,
        Vector3 cellScale,
        Vector3Int gridSize,
        string[] detectableTags,
        SensorCompressionType compression)
        : base(name, cellScale, gridSize, detectableTags, compression)
    {
        CompressionType = SensorCompressionType.None;
    }

    /// <summary>
    /// Number of observation values per cell: one slot per detectable tag,
    /// or zero when no tags are configured.
    /// </summary>
    /// <returns>The length of <c>DetectableTags</c>, or 0 if it is null.</returns>
    protected override int GetCellObservationSize()
    {
        var tags = DetectableTags;
        if (tags == null)
        {
            return 0;
        }

        return tags.Length;
    }

    /// <summary>
    /// Reports that the observation data is not normalized.
    /// </summary>
    /// <returns>Always <c>false</c>.</returns>
    protected override bool IsDataNormalized() => false;

    /// <summary>
    /// Selects the collider processing mode used by this sensor.
    /// </summary>
    /// <returns>Always <see cref="ProcessCollidersMethod.ProcessAllColliders"/>.</returns>
    protected internal override ProcessCollidersMethod GetProcessCollidersMethod() =>
        ProcessCollidersMethod.ProcessAllColliders;

    /// <summary>
    /// Get object counts for each detectable tag detected in a cell.
    /// Each call adds one to the buffer slot that belongs to the detected object's tag.
    /// </summary>
    /// <param name="detectedObject">The game object that was detected within a certain cell (not read by this implementation)</param>
    /// <param name="tagIndex">The index of the detectedObject's tag in the detectable tags list</param>
    /// <param name="dataBuffer">
    /// The buffer to write the observation values.
    /// NOTE(review): the buffer size is presumably configured by
    /// <see cref="GetCellObservationSize"/> - confirm against GridSensorBase.
    /// </param>
    protected override void GetObjectData(GameObject detectedObject, int tagIndex, float[] dataBuffer)
    {
        // Accumulate rather than assign, so repeated detections of the same tag add up.
        dataBuffer[tagIndex] += 1f;
    }
}
}