using System;
using UnityEngine;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// The grid perception strategy that uses box overlap to detect objects.
    /// </summary>
    internal class BoxOverlapChecker : IGridPerception
    {
        Vector3 m_CellScale;
        Vector3Int m_GridSize;
        bool m_RotateWithAgent;
        LayerMask m_ColliderMask;
        GameObject m_CenterObject;
        GameObject m_AgentGameObject;
        string[] m_DetectableTags;
        int m_InitialColliderBufferSize;
        int m_MaxColliderBufferSize;

        int m_NumCells;
        Vector3 m_HalfCellScale;
        Vector3 m_CellCenterOffset;
        Vector3[] m_CellLocalPositions;

#if MLA_UNITY_PHYSICS_MODULE
        Collider[] m_ColliderBuffer;

        /// <summary>
        /// Raised by <see cref="Perceive"/> for every non-agent collider found in a cell.
        /// Arguments are the detected GameObject and the index of the cell it was found in.
        /// </summary>
        public event Action<GameObject, int> GridOverlapDetectedAll;

        /// <summary>
        /// Raised by <see cref="Perceive"/> with the closest detectable GameObject of a cell, if any.
        /// Arguments are the detected GameObject and the index of the cell it was found in.
        /// </summary>
        public event Action<GameObject, int> GridOverlapDetectedClosest;

        /// <summary>
        /// Raised by <see cref="UpdateGizmo"/> with the closest detectable GameObject of a cell, if any.
        /// Arguments are the detected GameObject and the index of the cell it was found in.
        /// </summary>
        public event Action<GameObject, int> GridOverlapDetectedDebug;
#endif

        /// <summary>
        /// Creates the checker and precomputes the local position of every grid cell.
        /// </summary>
        /// <param name="cellScale">Size of a single cell.</param>
        /// <param name="gridSize">Number of cells along each axis; only x and z are used.</param>
        /// <param name="rotateWithAgent">Whether the grid follows the rotation of <paramref name="centerObject"/>.</param>
        /// <param name="colliderMask">Layer mask passed to the overlap query.</param>
        /// <param name="centerObject">Object whose transform the grid is centered on.</param>
        /// <param name="agentGameObject">Object that is never reported as a detection.</param>
        /// <param name="detectableTags">Tags a GameObject must carry to be considered by the "closest" parsing.</param>
        /// <param name="initialColliderBufferSize">Starting length of the collider result buffer.</param>
        /// <param name="maxColliderBufferSize">Upper bound for the length of the collider result buffer.</param>
        public BoxOverlapChecker(
            Vector3 cellScale,
            Vector3Int gridSize,
            bool rotateWithAgent,
            LayerMask colliderMask,
            GameObject centerObject,
            GameObject agentGameObject,
            string[] detectableTags,
            int initialColliderBufferSize,
            int maxColliderBufferSize)
        {
            m_CellScale = cellScale;
            m_GridSize = gridSize;
            m_RotateWithAgent = rotateWithAgent;
            m_ColliderMask = colliderMask;
            m_CenterObject = centerObject;
            m_AgentGameObject = agentGameObject;
            m_DetectableTags = detectableTags;
            m_InitialColliderBufferSize = initialColliderBufferSize;
            m_MaxColliderBufferSize = maxColliderBufferSize;

            m_NumCells = gridSize.x * gridSize.z;
            // x and z are halved, y is passed through unchanged.
            // NOTE(review): confirm that using the full y scale as a half-extent is intentional.
            m_HalfCellScale = new Vector3(cellScale.x / 2f, cellScale.y, cellScale.z / 2f);
            m_CellCenterOffset = new Vector3((gridSize.x - 1f) / 2, 0, (gridSize.z - 1f) / 2);
#if MLA_UNITY_PHYSICS_MODULE
            m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_InitialColliderBufferSize)];
#endif

            InitCellLocalPositions();
        }

        /// <summary>
        /// Whether cell positions and the query box are rotated with the center object.
        /// </summary>
        public bool RotateWithAgent
        {
            get { return m_RotateWithAgent; }
            set { m_RotateWithAgent = value; }
        }

        /// <summary>
        /// Layer mask used for the overlap queries.
        /// </summary>
        public LayerMask ColliderMask
        {
            get { return m_ColliderMask; }
            set { m_ColliderMask = value; }
        }

        /// <summary>
        /// Initializes the local location of the cells.
        /// </summary>
        void InitCellLocalPositions()
        {
            m_CellLocalPositions = new Vector3[m_NumCells];

            for (int i = 0; i < m_NumCells; i++)
            {
                m_CellLocalPositions[i] = GetCellLocalPosition(i);
            }
        }

        /// <summary>
        /// Computes the position of a cell relative to the grid center.
        /// </summary>
        /// <param name="cellIndex">Flat cell index; the x coordinate is <c>cellIndex / gridSize.z</c>
        /// (integer division) and the z coordinate is <c>cellIndex % gridSize.z</c>.</param>
        /// <returns>The local position of the cell, with y set to 0.</returns>
        public Vector3 GetCellLocalPosition(int cellIndex)
        {
            float x = (cellIndex / m_GridSize.z - m_CellCenterOffset.x) * m_CellScale.x;
            float z = (cellIndex % m_GridSize.z - m_CellCenterOffset.z) * m_CellScale.z;
            return new Vector3(x, 0, z);
        }

        /// <summary>
        /// Converts the precomputed local position of a cell to world space.
        /// </summary>
        /// <param name="cellIndex">Flat cell index.</param>
        /// <returns>The local position run through the center object's transform when rotating
        /// with the agent, otherwise the local position offset by the center object's position.</returns>
        public Vector3 GetCellGlobalPosition(int cellIndex)
        {
            if (m_RotateWithAgent)
            {
                return m_CenterObject.transform.TransformPoint(m_CellLocalPositions[cellIndex]);
            }
            else
            {
                return m_CellLocalPositions[cellIndex] + m_CenterObject.transform.position;
            }
        }

        /// <summary>
        /// Returns the rotation used for the query boxes.
        /// </summary>
        /// <returns>The center object's rotation when rotating with the agent, otherwise identity.</returns>
        public Quaternion GetGridRotation()
        {
            return m_RotateWithAgent ? m_CenterObject.transform.rotation : Quaternion.identity;
        }

        /// <summary>
        /// Runs an overlap query for every cell and forwards the results to the subscribers of
        /// the "all" and "closest" events. Does nothing when the physics module is unavailable.
        /// </summary>
        public void Perceive()
        {
#if MLA_UNITY_PHYSICS_MODULE
            for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)
            {
                var cellCenter = GetCellGlobalPosition(cellIndex);
                var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, m_HalfCellScale, GetGridRotation());

                if (GridOverlapDetectedAll != null)
                {
                    ParseCollidersAll(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedAll);
                }
                if (GridOverlapDetectedClosest != null)
                {
                    ParseCollidersClosest(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedClosest);
                }
            }
#endif
        }

        /// <summary>
        /// Runs an overlap query for every cell and forwards the closest detectable object of each
        /// cell to the debug subscribers. Does nothing when no debug sensor is registered or when
        /// the physics module is unavailable.
        /// </summary>
        public void UpdateGizmo()
        {
#if MLA_UNITY_PHYSICS_MODULE
            // Without subscribers there is nobody to notify; invoking a null delegate would throw.
            var debugHandler = GridOverlapDetectedDebug;
            if (debugHandler == null)
            {
                return;
            }

            for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)
            {
                var cellCenter = GetCellGlobalPosition(cellIndex);
                var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, m_HalfCellScale, GetGridRotation());

                ParseCollidersClosest(m_ColliderBuffer, numFound, cellIndex, cellCenter, debugHandler);
            }
#endif
        }

#if MLA_UNITY_PHYSICS_MODULE
        /// <summary>
        /// This method attempts to perform the Physics.OverlapBoxNonAlloc and will double the size of the Collider buffer
        /// if the number of Colliders in the buffer after the call is equal to the length of the buffer.
        /// </summary>
        /// <param name="cellCenter">World-space center of the query box.</param>
        /// <param name="halfCellScale">Half extents of the query box.</param>
        /// <param name="rotation">Rotation of the query box.</param>
        /// <returns>The number of colliders written to the collider buffer.</returns>
        int BufferResizingOverlapBoxNonAlloc(Vector3 cellCenter, Vector3 halfCellScale, Quaternion rotation)
        {
            int numFound;
            // Since we can only get a fixed number of results, requery
            // until we're sure we can hold them all (or until we hit the max size).
            while (true)
            {
                numFound = Physics.OverlapBoxNonAlloc(cellCenter, halfCellScale, m_ColliderBuffer, rotation, m_ColliderMask);
                if (numFound == m_ColliderBuffer.Length && m_ColliderBuffer.Length < m_MaxColliderBufferSize)
                {
                    // Grow to at least one element: doubling an empty buffer would stay empty
                    // and this loop would never terminate.
                    var grownLength = Math.Max(1, m_ColliderBuffer.Length * 2);
                    m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, grownLength)];
                    m_InitialColliderBufferSize = m_ColliderBuffer.Length;
                }
                else
                {
                    break;
                }
            }
            return numFound;
        }

        /// <summary>
        /// Checks whether the GameObject carries one of the detectable tags.
        /// </summary>
        /// <param name="candidate">The GameObject to test.</param>
        /// <returns>True if any detectable tag matches, false otherwise.</returns>
        bool HasDetectableTag(GameObject candidate)
        {
            for (var i = 0; i < m_DetectableTags.Length; i++)
            {
                if (candidate.CompareTag(m_DetectableTags[i]))
                {
                    return true;
                }
            }
            return false;
        }

        /// <summary>
        /// Parses the array of colliders found within a cell. Finds the closest gameobject to the agent root reference within the cell
        /// </summary>
        /// <param name="foundColliders">Buffer holding the colliders found for the cell.</param>
        /// <param name="numFound">Number of valid entries at the start of <paramref name="foundColliders"/>.</param>
        /// <param name="cellIndex">Index of the cell being parsed; forwarded to the callback.</param>
        /// <param name="cellCenter">World-space center of the cell.</param>
        /// <param name="detectedAction">Callback invoked with the closest detectable GameObject, if one exists.</param>
        void ParseCollidersClosest(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter, Action<GameObject, int> detectedAction)
        {
            if (detectedAction == null)
            {
                return;
            }

            GameObject closestColliderGo = null;
            var minDistanceSquared = float.MaxValue;
            // The center object does not move while the colliders are being scanned.
            var centerPosition = m_CenterObject.transform.position;

            for (var i = 0; i < numFound; i++)
            {
                var currentColliderGo = foundColliders[i].gameObject;

                // Continue if the current collider go is the root reference
                if (ReferenceEquals(currentColliderGo, m_AgentGameObject))
                {
                    continue;
                }

                var closestColliderPoint = foundColliders[i].ClosestPointOnBounds(cellCenter);
                var currentDistanceSquared = (closestColliderPoint - centerPosition).sqrMagnitude;

                // Distance is checked first so the tag scan only runs for closer candidates.
                if (currentDistanceSquared >= minDistanceSquared)
                {
                    continue;
                }

                // Checks if our colliders contain a detectable object
                if (HasDetectableTag(currentColliderGo))
                {
                    minDistanceSquared = currentDistanceSquared;
                    closestColliderGo = currentColliderGo;
                }
            }

            if (!ReferenceEquals(closestColliderGo, null))
            {
                detectedAction.Invoke(closestColliderGo, cellIndex);
            }
        }

        /// <summary>
        /// Parses all colliders in the array of colliders found within a cell.
        /// </summary>
        /// <param name="foundColliders">Buffer holding the colliders found for the cell.</param>
        /// <param name="numFound">Number of valid entries at the start of <paramref name="foundColliders"/>.</param>
        /// <param name="cellIndex">Index of the cell being parsed; forwarded to the callback.</param>
        /// <param name="cellCenter">World-space center of the cell (not used by this method).</param>
        /// <param name="detectedAction">Callback invoked once per collider whose GameObject is not the agent.</param>
        void ParseCollidersAll(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter, Action<GameObject, int> detectedAction)
        {
            if (detectedAction == null)
            {
                return;
            }

            for (int i = 0; i < numFound; i++)
            {
                var currentColliderGo = foundColliders[i].gameObject;
                if (!ReferenceEquals(currentColliderGo, m_AgentGameObject))
                {
                    detectedAction.Invoke(currentColliderGo, cellIndex);
                }
            }
        }
#endif

        /// <summary>
        /// Subscribes the sensor to either the "all" or the "closest" detection event,
        /// depending on the sensor's process-colliders method.
        /// </summary>
        /// <param name="sensor">The sensor whose ProcessDetectedObject is subscribed.</param>
        public void RegisterSensor(GridSensorBase sensor)
        {
#if MLA_UNITY_PHYSICS_MODULE
            if (sensor.GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders)
            {
                GridOverlapDetectedAll += sensor.ProcessDetectedObject;
            }
            else
            {
                GridOverlapDetectedClosest += sensor.ProcessDetectedObject;
            }
#endif
        }

        /// <summary>
        /// Subscribes the sensor to the debug detection event raised by <see cref="UpdateGizmo"/>.
        /// </summary>
        /// <param name="debugSensor">The sensor whose ProcessDetectedObject is subscribed.</param>
        public void RegisterDebugSensor(GridSensorBase debugSensor)
        {
#if MLA_UNITY_PHYSICS_MODULE
            GridOverlapDetectedDebug += debugSensor.ProcessDetectedObject;
#endif
        }
    }
}