using System; using UnityEngine; namespace Unity.MLAgents.Extensions.Sensors { public class CountingGridSensor : GridSensor { /// public override void InitDepthType() { ObservationPerCell = ChannelDepth.Length; } /// /// Overrides the initialization ofthe m_ChannelHotDefaultPerceptionBuffer with 0s /// as the counting grid sensor starts within its initialization equal to 0 /// public override void InitChannelHotDefaultPerceptionBuffer() { m_ChannelHotDefaultPerceptionBuffer = new float[ObservationPerCell]; } /// public override void SetParameters(string[] detectableObjects, int[] channelDepth, GridDepthType gridDepthType, float cellScaleX, float cellScaleZ, int gridWidth, int gridHeight, int observeMaskInt, bool rotateToAgent, Color[] debugColors) { this.ObserveMask = observeMaskInt; this.DetectableObjects = detectableObjects; this.ChannelDepth = channelDepth; if (DetectableObjects.Length != ChannelDepth.Length) throw new UnityAgentsException("The channels of a CountingGridSensor is equal to the number of detectableObjects"); this.gridDepthType = GridDepthType.Channel; this.CellScaleX = cellScaleX; this.CellScaleZ = cellScaleZ; this.GridNumSideX = gridWidth; this.GridNumSideZ = gridHeight; this.RotateToAgent = rotateToAgent; this.DiffNumSideZX = (GridNumSideZ - GridNumSideX); this.OffsetGridNumSide = (GridNumSideZ - 1f) / 2f; this.DebugColors = debugColors; } /// /// For each collider, calls LoadObjectData on the gameobejct /// /// The array of colliders /// The cell index the collider is in /// the center of the cell the collider is in protected override void ParseColliders(Collider[] foundColliders, int cellIndex, Vector3 cellCenter) { GameObject currentColliderGo = null; Vector3 closestColliderPoint = Vector3.zero; for (int i = 0; i < foundColliders.Length; i++) { currentColliderGo = foundColliders[i].gameObject; // Continue if the current collider go is the root reference if (currentColliderGo == rootReference) continue; closestColliderPoint = foundColliders[i].ClosestPointOnBounds(cellCenter); LoadObjectData(currentColliderGo, cellIndex, Vector3.Distance(closestColliderPoint, transform.position) / SphereRadius); } } /// /// Throws an execption as this should not be called from the CountingGridSensor class /// /// The current gameobject to get data from /// the index of the detectable tag of this gameobject /// The normalized distance to the gridsensor /// protected override float[] GetObjectData(GameObject currentColliderGo, float typeIndex, float normalizedDistance) { throw new Exception("GetObjectData isn't called within the CountingGridSensor"); } /// /// Adds 1 to the counting index for this gameobject of this type /// /// the current game object /// the index of the cell /// the normalized distance from the gameobject to the sensor protected override void LoadObjectData(GameObject currentColliderGo, int cellIndex, float normalizedDistance) { for (int i = 0; i < DetectableObjects.Length; i++) { if (currentColliderGo != null && currentColliderGo.CompareTag(DetectableObjects[i])) { if (ShowGizmos) { Color debugRayColor = Color.white; if (DebugColors.Length > 0) { debugRayColor = DebugColors[i]; } CellActivity[cellIndex] = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f); } /// /// The observations are "channel count" so each grid is WxHxC where C is the number of tags /// This means that each value channelValues[i] is a counter of gameobject included into grid cells where i is the index of the tag in DetectableObjects /// int countIndex = cellIndex * ObservationPerCell + i; m_PerceptionBuffer[countIndex] = Mathf.Min(1f, m_PerceptionBuffer[countIndex] + 1f / ChannelDepth[i]); break; } } } } }