using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// A SensorComponent that creates a <see cref="GridSensor"/>.
    /// </summary>
    [AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)]
    public class GridSensorComponent : SensorComponent
    {
        // The sensor built by CreateSensors(); null until then.
        // UpdateSensor() and OnDrawGizmos() both do nothing while this is null.
        protected GridSensor m_Sensor;

        [HideInInspector, SerializeField]
        internal string m_SensorName = "GridSensor";

        /// <summary>
        /// Name of the generated <see cref="GridSensor"/> object.
        /// Note that changing this at runtime does not affect how the Agent sorts the sensors.
        /// </summary>
        public string SensorName
        {
            get { return m_SensorName; }
            set { m_SensorName = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3 m_CellScale = new Vector3(1f, 0.01f, 1f);

        /// <summary>
        /// The scale of each grid cell.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public Vector3 CellScale
        {
            get { return m_CellScale; }
            set { m_CellScale = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3Int m_GridSize = new Vector3Int(16, 1, 16);

        /// <summary>
        /// The number of grid cells along each axis.
        /// The y component is always stored as 1; any other y value passed in is replaced by 1.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public Vector3Int GridSize
        {
            get { return m_GridSize; }
            set { m_GridSize = new Vector3Int(value.x, 1, value.z); }
        }

        [HideInInspector, SerializeField]
        internal bool m_RotateWithAgent = true;

        /// <summary>
        /// Rotate the grid based on the direction the agent is facing.
        /// Changing this at runtime is forwarded to an already-created sensor.
        /// </summary>
        public bool RotateWithAgent
        {
            get { return m_RotateWithAgent; }
            set
            {
                m_RotateWithAgent = value;
                // UpdateSensor() supports this field; push it so runtime changes take effect,
                // matching the CompressionType setter.
                UpdateSensor();
            }
        }

        [HideInInspector, SerializeField]
        internal int[] m_ChannelDepths = new int[] { 1 };

        /// <summary>
        /// Array holding the depth of each channel.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int[] ChannelDepths
        {
            get { return m_ChannelDepths; }
            set { m_ChannelDepths = value; }
        }

        [HideInInspector, SerializeField]
        internal string[] m_DetectableObjects;

        /// <summary>
        /// List of tags that are detected.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public string[] DetectableObjects
        {
            get { return m_DetectableObjects; }
            set { m_DetectableObjects = value; }
        }

        [HideInInspector, SerializeField]
        internal LayerMask m_ColliderMask;

        /// <summary>
        /// The layer mask.
        /// Changing this at runtime is forwarded to an already-created sensor.
        /// </summary>
        public LayerMask ColliderMask
        {
            get { return m_ColliderMask; }
            set
            {
                m_ColliderMask = value;
                // Keep a live sensor in sync, as UpdateSensor() already handles this field.
                UpdateSensor();
            }
        }

        [HideInInspector, SerializeField]
        internal GridDepthType m_DepthType = GridDepthType.Channel;

        /// <summary>
        /// The data layout that the grid should output.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public GridDepthType DepthType
        {
            get { return m_DepthType; }
            set { m_DepthType = value; }
        }

        [HideInInspector, SerializeField]
        internal GameObject m_RootReference;

        /// <summary>
        /// The reference of the root of the agent. This is used to disambiguate objects with the
        /// same tag as the agent. Defaults to the current GameObject when unset.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public GameObject RootReference
        {
            // Deliberately uses == rather than ?? so Unity's overloaded null check also treats a
            // destroyed GameObject as unset.
            get { return m_RootReference == null ? gameObject : m_RootReference; }
            set { m_RootReference = value; }
        }

        [HideInInspector, SerializeField]
        internal int m_MaxColliderBufferSize = 500;

        /// <summary>
        /// The absolute max size of the Collider buffer used in the non-allocating Physics calls.
        /// In other words, the Collider buffer will never grow beyond this number even if there are
        /// more Colliders in the grid cell.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int MaxColliderBufferSize
        {
            get { return m_MaxColliderBufferSize; }
            set { m_MaxColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal int m_InitialColliderBufferSize = 4;

        /// <summary>
        /// The estimated max number of Colliders to expect per cell. This number is used to
        /// pre-allocate an array of Colliders in order to take advantage of the OverlapBoxNonAlloc
        /// Physics API. If the number of colliders found is >= InitialColliderBufferSize, the array
        /// will be resized to double its current size, up to <see cref="MaxColliderBufferSize"/>
        /// (500 by default).
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int InitialColliderBufferSize
        {
            get { return m_InitialColliderBufferSize; }
            set { m_InitialColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal Color[] m_DebugColors;

        /// <summary>
        /// Array of Colors used for the grid gizmos. May be null, in which case every cell is drawn white.
        /// </summary>
        public Color[] DebugColors
        {
            get { return m_DebugColors; }
            set { m_DebugColors = value; }
        }

        [HideInInspector, SerializeField]
        internal float m_GizmoYOffset = 0f;

        /// <summary>
        /// The height offset at which the gizmo grid is drawn.
        /// </summary>
        public float GizmoYOffset
        {
            get { return m_GizmoYOffset; }
            set { m_GizmoYOffset = value; }
        }

        [HideInInspector, SerializeField]
        internal bool m_ShowGizmos = false;

        /// <summary>
        /// Whether to show gizmos or not.
        /// </summary>
        public bool ShowGizmos
        {
            get { return m_ShowGizmos; }
            set { m_ShowGizmos = value; }
        }

        [HideInInspector, SerializeField]
        internal SensorCompressionType m_CompressionType = SensorCompressionType.PNG;

        /// <summary>
        /// The compression type to use for the sensor.
        /// Changing this at runtime is forwarded to an already-created sensor.
        /// </summary>
        public SensorCompressionType CompressionType
        {
            get { return m_CompressionType; }
            set
            {
                m_CompressionType = value;
                UpdateSensor();
            }
        }

        [HideInInspector, SerializeField]
        [Range(1, 50)]
        [Tooltip("Number of frames of observations that will be stacked before being fed to the neural network.")]
        internal int m_ObservationStacks = 1;

        /// <summary>
        /// Whether to stack previous observations. Using 1 means no previous observations.
        /// Values below 1 are treated like 1 (no stacking).
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int ObservationStacks
        {
            get { return m_ObservationStacks; }
            set { m_ObservationStacks = value; }
        }

        /// <inheritdoc/>
        public override ISensor[] CreateSensors()
        {
            m_Sensor = new GridSensor(
                m_SensorName,
                m_CellScale,
                m_GridSize,
                m_RotateWithAgent,
                m_ChannelDepths,
                m_DetectableObjects,
                m_ColliderMask,
                m_DepthType,
                RootReference,
                m_CompressionType,
                m_MaxColliderBufferSize,
                m_InitialColliderBufferSize
            );

            // [Range(1, 50)] only constrains the inspector; the public setter can store 0 or a
            // negative value, so only wrap in a StackingSensor when more than one frame is requested.
            if (ObservationStacks > 1)
            {
                return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) };
            }
            return new ISensor[] { m_Sensor };
        }

        /// <summary>
        /// Update fields that are safe to change on the Sensor at runtime
        /// (compression type, rotate-with-agent and collider mask).
        /// Does nothing if <see cref="CreateSensors"/> has not created the sensor yet.
        /// </summary>
        internal void UpdateSensor()
        {
            if (m_Sensor != null)
            {
                m_Sensor.CompressionType = m_CompressionType;
                m_Sensor.RotateWithAgent = m_RotateWithAgent;
                m_Sensor.ColliderMask = m_ColliderMask;
            }
        }

        /// <summary>
        /// Draws each grid cell as a semi-transparent cube when <see cref="ShowGizmos"/> is enabled.
        /// A cell uses the <see cref="DebugColors"/> entry at the index reported by the sensor, or
        /// white when that index is negative, out of range, or DebugColors is null.
        /// Nothing is drawn before <see cref="CreateSensors"/> has created the sensor.
        /// </summary>
        void OnDrawGizmos()
        {
            if (!m_ShowGizmos || m_Sensor == null)
            {
                return;
            }

            var cellColors = m_Sensor.PerceiveGizmoColor();
            var cellPositions = m_Sensor.GetGizmoPositions();
            var rotation = m_Sensor.GetGridRotation();

            // Cubes use a fixed y-scale of 1 rather than CellScale.y.
            var scale = new Vector3(m_CellScale.x, 1, m_CellScale.z);
            var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0);
            var oldGizmoMatrix = Gizmos.matrix;
            try
            {
                for (var i = 0; i < cellPositions.Length; i++)
                {
                    var cubeTransform = Matrix4x4.TRS(cellPositions[i] + gizmoYOffset, rotation, scale);
                    Gizmos.matrix = oldGizmoMatrix * cubeTransform;

                    var colorIndex = cellColors[i];
                    var debugRayColor = Color.white;
                    // DebugColors can be null when assigned via the public setter; fall back to white.
                    if (m_DebugColors != null && colorIndex >= 0 && colorIndex < m_DebugColors.Length)
                    {
                        debugRayColor = m_DebugColors[colorIndex];
                    }
                    Gizmos.color = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
                    Gizmos.DrawCube(Vector3.zero, Vector3.one);
                }
            }
            finally
            {
                // Gizmos.matrix is global editor state; always restore it, even if drawing throws.
                Gizmos.matrix = oldGizmoMatrix;
            }
        }
    }
}