using System;
using Unity.Collections;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Grid sensor whose per-cell observation is a normalized count of the colliders handed to
    /// <see cref="ParseColliders"/> for that cell, with one channel per detectable tag.
    /// Each counted collider adds <c>1 / ChannelDepth[i]</c> to channel <c>i</c>, clamped to 1,
    /// so <c>ChannelDepth[i]</c> is (approximately) the count at which channel <c>i</c> saturates.
    /// </summary>
    public class CountingGridSensor : GridSensor
    {
        /// <summary>
        /// Uses one observation per channel, i.e. <c>ChannelDepth.Length</c> counters per cell.
        /// </summary>
        public override void InitDepthType()
        {
            ObservationPerCell = ChannelDepth.Length;
        }

        /// <summary>
        /// Overrides the initialization of <c>m_ChannelHotDefaultPerceptionBuffer</c> with zeros,
        /// because every counter of the counting grid sensor starts at 0.
        /// </summary>
        public override void InitChannelHotDefaultPerceptionBuffer()
        {
            // ClearMemory is the NativeArray default; it is passed explicitly so the zero-fill
            // this override depends on is visible at the call site.
            m_ChannelHotDefaultPerceptionBuffer = new NativeArray<float>(
                ObservationPerCell, Allocator.Persistent, NativeArrayOptions.ClearMemory);
        }

        /// <summary>
        /// Configures the sensor. All arguments are validated before any field is assigned,
        /// so a rejected call leaves the previous configuration untouched.
        /// </summary>
        /// <param name="detectableObjects">Tags to detect; channel <c>i</c> counts tag <c>i</c>.</param>
        /// <param name="channelDepth">
        /// Per-channel count at which the channel saturates to 1. Must have the same length as
        /// <paramref name="detectableObjects"/>.
        /// </param>
        /// <param name="gridDepthType">Depth type stored on the base sensor.</param>
        /// <param name="cellScaleX">Cell size along X.</param>
        /// <param name="cellScaleZ">Cell size along Z.</param>
        /// <param name="gridWidth">Number of cells along X.</param>
        /// <param name="gridHeight">Number of cells along Z.</param>
        /// <param name="observeMaskInt">Layer mask value stored in <c>ObserveMask</c>.</param>
        /// <param name="rotateToAgent">Stored in <c>RotateToAgent</c>.</param>
        /// <param name="debugColors">Gizmo colors, indexed like <paramref name="detectableObjects"/>; may be shorter or null.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="detectableObjects"/> or <paramref name="channelDepth"/> is null.
        /// </exception>
        /// <exception cref="UnityAgentsException">The two arrays have different lengths.</exception>
        public override void SetParameters(string[] detectableObjects, int[] channelDepth, GridDepthType gridDepthType,
            float cellScaleX, float cellScaleZ, int gridWidth, int gridHeight, int observeMaskInt, bool rotateToAgent, Color[] debugColors)
        {
            if (detectableObjects is null)
            {
                throw new ArgumentNullException(nameof(detectableObjects));
            }
            if (channelDepth is null)
            {
                throw new ArgumentNullException(nameof(channelDepth));
            }
            if (detectableObjects.Length != channelDepth.Length)
            {
                throw new UnityAgentsException(
                    $"The number of channels of a CountingGridSensor ({channelDepth.Length}) must be equal to the number of detectableObjects ({detectableObjects.Length})");
            }

            this.ObserveMask = observeMaskInt;
            this.DetectableObjects = detectableObjects;
            this.ChannelDepth = channelDepth;
            // `this.` is required here: the base field name matches the parameter name.
            this.gridDepthType = gridDepthType;
            this.CellScaleX = cellScaleX;
            this.CellScaleZ = cellScaleZ;
            this.GridNumSideX = gridWidth;
            this.GridNumSideZ = gridHeight;
            this.RotateToAgent = rotateToAgent;
            this.DiffNumSideZX = (GridNumSideZ - GridNumSideX);
            this.OffsetGridNumSide = (GridNumSideZ - 1f) / 2f;
            this.DebugColors = debugColors;
        }

        /// <summary>
        /// For each collider whose tag is detectable, calls <see cref="LoadObjectData"/> on its game object.
        /// </summary>
        /// <param name="foundColliders">The array of colliders; only the first <paramref name="numFound"/> are read.</param>
        /// <param name="numFound">The number of valid entries in <paramref name="foundColliders"/>.</param>
        /// <param name="cellIndex">The index of the cell the colliders are in.</param>
        /// <param name="cellCenter">The center of that cell.</param>
        /// <remarks>
        /// NOTE(review): a game object with several colliders in <paramref name="foundColliders"/>
        /// is counted once per collider.
        /// </remarks>
        protected override void ParseColliders(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter)
        {
            for (int i = 0; i < numFound; i++)
            {
                Collider foundCollider = foundColliders[i];
                GameObject colliderGo = foundCollider.gameObject;

                // Never count the object the sensor is attached to.
                if (colliderGo == rootReference)
                {
                    continue;
                }

                if (!m_DetectableObjectToIndex.TryGetValue(colliderGo.tag, out var detectableIndex))
                {
                    continue;
                }

                // Only compute the closest point / distance for colliders that will actually be counted.
                Vector3 closestColliderPoint = foundCollider.ClosestPointOnBounds(cellCenter);
                float normalizedDistance = Vector3.Distance(closestColliderPoint, transform.position) * InverseSphereRadius;
                LoadObjectData(colliderGo, cellIndex, detectableIndex, normalizedDistance);
            }
        }

        /// <summary>
        /// Not supported: the counting sensor writes its observations directly in
        /// <see cref="LoadObjectData"/> and never builds per-object data arrays.
        /// </summary>
        /// <param name="currentColliderGo">The current game object to get data from.</param>
        /// <param name="typeIndex">The index of the detectable tag of this game object.</param>
        /// <param name="normalizedDistance">The normalized distance to the grid sensor.</param>
        /// <returns>Never returns.</returns>
        /// <exception cref="NotSupportedException">Always thrown.</exception>
        protected override float[] GetObjectData(GameObject currentColliderGo, float typeIndex, float normalizedDistance)
        {
            throw new NotSupportedException("GetObjectData isn't called within the CountingGridSensor");
        }

        /// <summary>
        /// Adds one (normalized) count for a game object of the given detectable type to the given cell.
        /// </summary>
        /// <param name="currentColliderGo">The current game object (unused by the counting sensor).</param>
        /// <param name="cellIndex">The index of the cell.</param>
        /// <param name="detectableIndex">The index of the object's tag in <c>DetectableObjects</c>.</param>
        /// <param name="normalizedDistance">
        /// The normalized distance from the game object to the sensor (unused by the counting sensor).
        /// </param>
        protected override void LoadObjectData(GameObject currentColliderGo, int cellIndex, int detectableIndex, float normalizedDistance)
        {
            if (ShowGizmos)
            {
                // Fall back to white when no debug color is configured for this tag,
                // instead of indexing past the end of (or into a null) DebugColors.
                Color debugRayColor = DebugColors != null && detectableIndex < DebugColors.Length
                    ? DebugColors[detectableIndex]
                    : Color.white;
                CellActivity[cellIndex] = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
            }

            // The observations are channel counts, so the grid is W x H x C where C is the number of tags.
            // Channel i of a cell accumulates 1 / ChannelDepth[i] per counted object, clamped to 1.
            int countIndex = cellIndex * ObservationPerCell + detectableIndex;
            m_PerceptionBuffer[countIndex] = Mathf.Min(1f, m_PerceptionBuffer[countIndex] + 1f / ChannelDepth[detectableIndex]);
        }
    }
}