using System;
using Unity.Collections;
using UnityEngine;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Grid sensor whose observation for each cell is, per detectable tag, a normalized count of the
    /// colliders carrying that tag that were found in the cell.
    /// </summary>
    public class CountingGridSensor : GridSensor
    {
        /// <inheritdoc/>
        /// <remarks>
        /// There is one channel per detectable tag, so the number of observations per cell is
        /// <c>ChannelDepth.Length</c>.
        /// </remarks>
        public override void InitDepthType()
        {
            ObservationPerCell = ChannelDepth.Length;
        }

        /// <summary>
        /// Overrides the initialization of <c>m_ChannelHotDefaultPerceptionBuffer</c> with 0s,
        /// because every count of the counting grid sensor starts at 0.
        /// </summary>
        public override void InitChannelHotDefaultPerceptionBuffer()
        {
            // The NativeArray constructor clears memory by default (NativeArrayOptions.ClearMemory),
            // so the buffer starts out all zeros.
            // NOTE(review): allocated with Allocator.Persistent; disposal is not visible here — confirm
            // the base class disposes it (and any previous buffer) to avoid a native leak.
            m_ChannelHotDefaultPerceptionBuffer = new NativeArray<float>(ObservationPerCell, Allocator.Persistent);
        }

        /// <inheritdoc/>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="detectableObjects"/> or <paramref name="channelDepth"/> is null.
        /// </exception>
        /// <exception cref="UnityAgentsException">
        /// The number of channels differs from the number of detectable objects.
        /// </exception>
        public override void SetParameters(string[] detectableObjects, int[] channelDepth, GridDepthType gridDepthType,
            float cellScaleX, float cellScaleZ, int gridWidth, int gridHeight, int observeMaskInt, bool rotateToAgent,
            Color[] debugColors)
        {
            // Validate before mutating any state so a rejected call leaves the sensor unchanged.
            if (detectableObjects == null)
            {
                throw new ArgumentNullException(nameof(detectableObjects));
            }
            if (channelDepth == null)
            {
                throw new ArgumentNullException(nameof(channelDepth));
            }
            if (detectableObjects.Length != channelDepth.Length)
            {
                throw new UnityAgentsException(
                    "The channels of a CountingGridSensor must be equal to the number of detectableObjects");
            }

            this.ObserveMask = observeMaskInt;
            this.DetectableObjects = detectableObjects;
            this.ChannelDepth = channelDepth;
            this.gridDepthType = gridDepthType;
            this.CellScaleX = cellScaleX;
            this.CellScaleZ = cellScaleZ;
            this.GridNumSideX = gridWidth;
            this.GridNumSideZ = gridHeight;
            this.RotateToAgent = rotateToAgent;
            this.DiffNumSideZX = (GridNumSideZ - GridNumSideX);
            this.OffsetGridNumSide = (GridNumSideZ - 1f) / 2f;
            this.DebugColors = debugColors;
        }

        /// <summary>
        /// For each collider whose tag is a detectable object, calls <see cref="LoadObjectData"/> on its
        /// game object.
        /// </summary>
        /// <param name="foundColliders">The array of colliders found in the cell.</param>
        /// <param name="numFound">How many leading entries of <paramref name="foundColliders"/> are valid.</param>
        /// <param name="cellIndex">The index of the cell the colliders are in.</param>
        /// <param name="cellCenter">The center of the cell the colliders are in.</param>
        protected override void ParseColliders(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter)
        {
            for (int i = 0; i < numFound; i++)
            {
                Collider collider = foundColliders[i];
                GameObject currentColliderGo = collider.gameObject;

                // Colliders on rootReference are never counted
                // (presumably the sensor's own object — confirm against the base class).
                if (currentColliderGo == rootReference)
                {
                    continue;
                }

                if (m_DetectableObjectToIndex.TryGetValue(currentColliderGo.tag, out var detectableIndex))
                {
                    // Only computed for detectable tags; the lookup result is unused otherwise.
                    Vector3 closestColliderPoint = collider.ClosestPointOnBounds(cellCenter);
                    float normalizedDistance =
                        Vector3.Distance(closestColliderPoint, transform.position) * InverseSphereRadius;
                    LoadObjectData(currentColliderGo, cellIndex, detectableIndex, normalizedDistance);
                }
            }
        }

        /// <summary>
        /// Not supported: the counting grid sensor writes counts directly in
        /// <see cref="LoadObjectData"/> and never calls this method.
        /// </summary>
        /// <param name="currentColliderGo">The current game object to get data from.</param>
        /// <param name="typeIndex">The index of the detectable tag of this game object.</param>
        /// <param name="normalizedDistance">The normalized distance to the grid sensor.</param>
        /// <returns>Never returns.</returns>
        /// <exception cref="NotSupportedException">Always thrown.</exception>
        protected override float[] GetObjectData(GameObject currentColliderGo, float typeIndex, float normalizedDistance)
        {
            throw new NotSupportedException("GetObjectData isn't called within the CountingGridSensor");
        }

        /// <summary>
        /// Increments the count for this game object's tag in the given cell.
        /// </summary>
        /// <remarks>
        /// Each detected object adds <c>1 / ChannelDepth[detectableIndex]</c> to its tag's channel, clamped
        /// to 1. A channel therefore saturates once <c>ChannelDepth[detectableIndex]</c> objects of that tag
        /// have been counted in the cell.
        /// </remarks>
        /// <param name="currentColliderGo">The current game object.</param>
        /// <param name="cellIndex">The index of the cell.</param>
        /// <param name="detectableIndex">The index of the object's tag in <c>DetectableObjects</c>.</param>
        /// <param name="normalizedDistance">
        /// The normalized distance from the game object to the sensor (unused by this sensor).
        /// </param>
        protected override void LoadObjectData(GameObject currentColliderGo, int cellIndex, int detectableIndex, float normalizedDistance)
        {
            if (ShowGizmos)
            {
                Color debugRayColor = Color.white;
                // DebugColors may be null or hold fewer entries than there are detectable tags;
                // fall back to white instead of indexing out of range.
                if (DebugColors != null && detectableIndex < DebugColors.Length)
                {
                    debugRayColor = DebugColors[detectableIndex];
                }
                CellActivity[cellIndex] = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
            }

            // Observations are laid out as W x H x C, where C is the number of tags, so the count for this
            // tag in this cell lives at cellIndex * ObservationPerCell + detectableIndex.
            int countIndex = cellIndex * ObservationPerCell + detectableIndex;
            m_PerceptionBuffer[countIndex] =
                Mathf.Min(1f, m_PerceptionBuffer[countIndex] + (1f / ChannelDepth[detectableIndex]));
        }
    }
}