using System.Collections.Generic;
using System.Linq;
using UnityEngine;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// A SensorComponent that creates a <see cref="GridSensorBase"/>.
    /// </summary>
    [AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)]
    public class GridSensorComponent : SensorComponent
    {
        // dummy sensor only used for debug gizmo
        GridSensorBase m_DebugSensor;

        // Sensors produced by the last call to CreateSensors(); null until then.
        List<GridSensorBase> m_Sensors;

        // Overlap checker shared by all sensors of this component; created in CreateSensors().
        internal BoxOverlapChecker m_BoxOverlapChecker;

        [HideInInspector, SerializeField]
        protected internal string m_SensorName = "GridSensor";

        /// <summary>
        /// Name of the generated GridSensor object.
        /// Note that changing this at runtime does not affect how the Agent sorts the sensors.
        /// </summary>
        public string SensorName
        {
            get { return m_SensorName; }
            set { m_SensorName = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3 m_CellScale = new Vector3(1f, 0.01f, 1f);

        /// <summary>
        /// The scale of each grid cell.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public Vector3 CellScale
        {
            get { return m_CellScale; }
            set { m_CellScale = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3Int m_GridSize = new Vector3Int(16, 1, 16);

        /// <summary>
        /// The number of grid cells on each side. The y component is always stored as 1,
        /// whatever value is assigned.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public Vector3Int GridSize
        {
            get { return m_GridSize; }
            set
            {
                // Only x and z are taken from the caller; y is pinned to 1.
                m_GridSize = new Vector3Int(value.x, 1, value.z);
            }
        }

        [HideInInspector, SerializeField]
        internal bool m_RotateWithAgent = true;

        /// <summary>
        /// Rotate the grid based on the direction the agent is facing.
        /// If the sensors have already been created, the new value is forwarded to them.
        /// </summary>
        public bool RotateWithAgent
        {
            get { return m_RotateWithAgent; }
            set { m_RotateWithAgent = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        internal GameObject m_AgentGameObject;

        /// <summary>
        /// The reference of the root of the agent. This is used to disambiguate objects with
        /// the same tag as the agent. Defaults to current GameObject.
        /// </summary>
        public GameObject AgentGameObject
        {
            // Keep the explicit "== null" comparison: UnityEngine.Object overloads the operator.
            get { return (m_AgentGameObject == null ? gameObject : m_AgentGameObject); }
            set { m_AgentGameObject = value; }
        }

        [HideInInspector, SerializeField]
        internal string[] m_DetectableTags;

        /// <summary>
        /// List of tags that are detected.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public string[] DetectableTags
        {
            get { return m_DetectableTags; }
            set { m_DetectableTags = value; }
        }

        [HideInInspector, SerializeField]
        internal LayerMask m_ColliderMask;

        /// <summary>
        /// The layer mask.
        /// If the sensors have already been created, the new value is forwarded to them.
        /// </summary>
        public LayerMask ColliderMask
        {
            get { return m_ColliderMask; }
            set { m_ColliderMask = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        internal int m_MaxColliderBufferSize = 500;

        /// <summary>
        /// The absolute max size of the Collider buffer used in the non-allocating Physics calls. In other words
        /// the Collider buffer will never grow beyond this number even if there are more Colliders in the Grid Cell.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int MaxColliderBufferSize
        {
            get { return m_MaxColliderBufferSize; }
            set { m_MaxColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal int m_InitialColliderBufferSize = 4;

        /// <summary>
        /// The Estimated Max Number of Colliders to expect per cell. This number is used to
        /// pre-allocate an array of Colliders in order to take advantage of the OverlapBoxNonAlloc
        /// Physics API. If the number of colliders found is >= InitialColliderBufferSize the array
        /// will be resized to double its current size. The upper bound of the buffer is
        /// <see cref="MaxColliderBufferSize"/> (500 by default).
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int InitialColliderBufferSize
        {
            get { return m_InitialColliderBufferSize; }
            set { m_InitialColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal Color[] m_DebugColors;

        /// <summary>
        /// Array of Colors used for the grid gizmos.
        /// </summary>
        public Color[] DebugColors
        {
            get { return m_DebugColors; }
            set { m_DebugColors = value; }
        }

        [HideInInspector, SerializeField]
        internal float m_GizmoYOffset = 0f;

        /// <summary>
        /// The height of the gizmos grid.
        /// </summary>
        public float GizmoYOffset
        {
            get { return m_GizmoYOffset; }
            set { m_GizmoYOffset = value; }
        }

        [HideInInspector, SerializeField]
        internal bool m_ShowGizmos = false;

        /// <summary>
        /// Whether to show gizmos or not.
        /// </summary>
        public bool ShowGizmos
        {
            get { return m_ShowGizmos; }
            set { m_ShowGizmos = value; }
        }

        [HideInInspector, SerializeField]
        internal SensorCompressionType m_CompressionType = SensorCompressionType.PNG;

        /// <summary>
        /// The compression type to use for the sensor.
        /// If the sensors have already been created, the new value is forwarded to them.
        /// </summary>
        public SensorCompressionType CompressionType
        {
            get { return m_CompressionType; }
            set { m_CompressionType = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        [Range(1, 50)]
        [Tooltip("Number of frames of observations that will be stacked before being fed to the neural network.")]
        internal int m_ObservationStacks = 1;

        /// <summary>
        /// Whether to stack previous observations. Using 1 means no previous observations.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int ObservationStacks
        {
            get { return m_ObservationStacks; }
            set { m_ObservationStacks = value; }
        }

        /// <summary>
        /// Creates the overlap checker, the debug sensor used by the gizmo, and the grid sensors
        /// returned by <see cref="GetGridSensors"/>. Each sensor is wrapped in a
        /// <see cref="StackingSensor"/> when <see cref="ObservationStacks"/> is not 1.
        /// </summary>
        /// <returns>The sensors created for this component.</returns>
        /// <exception cref="UnityAgentsException">
        /// Thrown when <see cref="GetGridSensors"/> returns null or an empty array.
        /// </exception>
        public override ISensor[] CreateSensors()
        {
            m_BoxOverlapChecker = new BoxOverlapChecker(
                m_CellScale,
                m_GridSize,
                m_RotateWithAgent,
                m_ColliderMask,
                gameObject,
                AgentGameObject,
                m_DetectableTags,
                m_InitialColliderBufferSize,
                m_MaxColliderBufferSize
            );

            // debug data is positive int value and will trigger data validation exception if SensorCompressionType is not None.
            m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
            m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor);

            // An override may return null; convert defensively so the descriptive exception
            // below is raised instead of an ArgumentNullException from ToList().
            var gridSensors = GetGridSensors();
            m_Sensors = gridSensors == null ? null : gridSensors.ToList();
            if (m_Sensors == null || m_Sensors.Count < 1)
            {
                throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors." +
                    " If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor.");
            }

            // Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
            m_Sensors[0].m_BoxOverlapChecker = m_BoxOverlapChecker;
            foreach (var sensor in m_Sensors)
            {
                m_BoxOverlapChecker.RegisterSensor(sensor);
            }

            if (ObservationStacks != 1)
            {
                var sensors = new ISensor[m_Sensors.Count];
                for (var i = 0; i < m_Sensors.Count; i++)
                {
                    sensors[i] = new StackingSensor(m_Sensors[i], ObservationStacks);
                }
                return sensors;
            }

            return m_Sensors.ToArray();
        }

        /// <summary>
        /// Get an array of GridSensors to be added in this component.
        /// Override this method and return custom GridSensor implementations.
        /// </summary>
        /// <returns>Array of grid sensors to be added to the component.</returns>
        protected virtual GridSensorBase[] GetGridSensors()
        {
            var sensor = new OneHotGridSensor(m_SensorName + "-OneHot", m_CellScale, m_GridSize, m_DetectableTags, m_CompressionType);
            return new GridSensorBase[] { sensor };
        }

        /// <summary>
        /// Update fields that are safe to change on the Sensor at runtime.
        /// Does nothing until <see cref="CreateSensors"/> has populated the sensor list.
        /// </summary>
        internal void UpdateSensor()
        {
            if (m_Sensors == null)
            {
                return;
            }

            m_BoxOverlapChecker.RotateWithAgent = m_RotateWithAgent;
            m_BoxOverlapChecker.ColliderMask = m_ColliderMask;
            foreach (var sensor in m_Sensors)
            {
                sensor.CompressionType = m_CompressionType;
            }
        }

        // Draws one semi-transparent cube per grid cell, colored by the value the debug sensor
        // recorded for that cell. Cells with no matching entry in m_DebugColors are drawn white.
        void OnDrawGizmos()
        {
            if (!m_ShowGizmos || m_BoxOverlapChecker == null || m_DebugSensor == null)
            {
                return;
            }

            m_DebugSensor.ResetPerceptionBuffer();
            m_BoxOverlapChecker.UpdateGizmo();
            var cellColors = m_DebugSensor.PerceptionBuffer;
            var rotation = m_BoxOverlapChecker.GetGridRotation();

            var scale = new Vector3(m_CellScale.x, 1, m_CellScale.z);
            var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0);
            var oldGizmoMatrix = Gizmos.matrix;

            // m_DebugColors can be null (e.g. component added from script and never assigned);
            // treat that the same as an empty palette.
            var debugColorCount = m_DebugColors == null ? 0 : m_DebugColors.Length;

            for (var i = 0; i < cellColors.Length; i++)
            {
                var cellPosition = m_BoxOverlapChecker.GetCellGlobalPosition(i);
                var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, scale);
                Gizmos.matrix = oldGizmoMatrix * cubeTransform;

                // Buffer values are offset by one relative to the color array index.
                var colorIndex = cellColors[i] - 1;
                var debugRayColor = Color.white;
                if (colorIndex > -1 && debugColorCount > colorIndex)
                {
                    debugRayColor = m_DebugColors[(int)colorIndex];
                }
                Gizmos.color = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
                Gizmos.DrawCube(Vector3.zero, Vector3.one);
            }

            Gizmos.matrix = oldGizmoMatrix;
        }
    }
}