using System;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Serialization;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// A base class to support sensor components for raycast-based sensors.
    /// </summary>
    public abstract class RayPerceptionSensorComponentBase : SensorComponent
    {
        [HideInInspector, SerializeField, FormerlySerializedAs("sensorName")]
        string m_SensorName = "RayPerceptionSensor";

        /// <summary>
        /// The name of the Sensor that this component wraps.
        /// Note that changing this at runtime does not affect how the Agent sorts the sensors.
        /// </summary>
        public string SensorName
        {
            get { return m_SensorName; }
            set { m_SensorName = value; }
        }

        [SerializeField, FormerlySerializedAs("detectableTags")]
        [Tooltip("List of tags in the scene to compare against.")]
        List<string> m_DetectableTags;

        /// <summary>
        /// List of tags in the scene to compare against.
        /// Note that this should not be changed at runtime; the setter does not
        /// push the new value to an already-created sensor.
        /// </summary>
        public List<string> DetectableTags
        {
            get { return m_DetectableTags; }
            set { m_DetectableTags = value; }
        }

        [HideInInspector, SerializeField, FormerlySerializedAs("raysPerDirection")]
        [Range(0, 50)]
        [Tooltip("Number of rays to the left and right of center.")]
        int m_RaysPerDirection = 3;

        /// <summary>
        /// Number of rays to the left and right of center.
        /// Note that this should not be changed at runtime.
        /// </summary>
        public int RaysPerDirection
        {
            get { return m_RaysPerDirection; }
            // Note: can't change at runtime; the setter only stores the value and
            // does not update an existing sensor.
            set { m_RaysPerDirection = value; }
        }

        [HideInInspector, SerializeField, FormerlySerializedAs("maxRayDegrees")]
        [Range(0, 180)]
        [Tooltip("Cone size for rays. Using 90 degrees will cast rays to the left and right. " +
            "Greater than 90 degrees will go backwards.")]
        float m_MaxRayDegrees = 70;

        /// <summary>
        /// Cone size for rays. Using 90 degrees will cast rays to the left and right.
        /// Greater than 90 degrees will go backwards.
        /// Setting this refreshes the input of the existing sensor, if one was created.
        /// </summary>
        public float MaxRayDegrees
        {
            get => m_MaxRayDegrees;
            set { m_MaxRayDegrees = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField, FormerlySerializedAs("sphereCastRadius")]
        [Range(0f, 10f)]
        [Tooltip("Radius of sphere to cast. Set to zero for raycasts.")]
        float m_SphereCastRadius = 0.5f;

        /// <summary>
        /// Radius of sphere to cast. Set to zero for raycasts.
        /// Setting this refreshes the input of the existing sensor, if one was created.
        /// </summary>
        public float SphereCastRadius
        {
            get => m_SphereCastRadius;
            set { m_SphereCastRadius = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField, FormerlySerializedAs("rayLength")]
        [Range(1, 1000)]
        [Tooltip("Length of the rays to cast.")]
        float m_RayLength = 20f;

        /// <summary>
        /// Length of the rays to cast.
        /// Setting this refreshes the input of the existing sensor, if one was created.
        /// </summary>
        public float RayLength
        {
            get => m_RayLength;
            set { m_RayLength = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField, FormerlySerializedAs("rayLayerMask")]
        [Tooltip("Controls which layers the rays can hit.")]
        LayerMask m_RayLayerMask = Physics.DefaultRaycastLayers;

        /// <summary>
        /// Controls which layers the rays can hit.
        /// Setting this refreshes the input of the existing sensor, if one was created.
        /// </summary>
        public LayerMask RayLayerMask
        {
            get => m_RayLayerMask;
            set { m_RayLayerMask = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField, FormerlySerializedAs("observationStacks")]
        [Range(1, 50)]
        [Tooltip("Number of raycast results that will be stacked before being fed to the neural network.")]
        int m_ObservationStacks = 1;

        /// <summary>
        /// Number of observations to stack. Using 1 means no previous observations.
        /// Note that changing this after the sensor is created has no effect.
        /// </summary>
        public int ObservationStacks
        {
            get { return m_ObservationStacks; }
            set { m_ObservationStacks = value; }
        }

        /// <summary>
        /// Color to code a ray that hits another object.
        /// </summary>
        [HideInInspector]
        [SerializeField]
        [Header("Debug Gizmos", order = 999)]
        internal Color rayHitColor = Color.red;

        /// <summary>
        /// Color to code a ray that avoids or misses all other objects.
        /// </summary>
        [HideInInspector]
        [SerializeField]
        internal Color rayMissColor = Color.white;

        [NonSerialized]
        RayPerceptionSensor m_RaySensor;

        /// <summary>
        /// Get the RayPerceptionSensor that was created.
        /// This is null until <see cref="CreateSensor"/> has been called.
        /// </summary>
        public RayPerceptionSensor RaySensor
        {
            get => m_RaySensor;
        }

        /// <summary>
        /// Returns the <see cref="RayPerceptionCastType"/> for the associated raycast sensor.
        /// </summary>
        /// <returns>The cast type used when building the ray input.</returns>
        public abstract RayPerceptionCastType GetCastType();

        /// <summary>
        /// Returns the amount that the ray start is offset up or down by.
        /// </summary>
        /// <returns>The start offset; 0 unless overridden.</returns>
        public virtual float GetStartVerticalOffset()
        {
            return 0f;
        }

        /// <summary>
        /// Returns the amount that the ray end is offset up or down by.
        /// </summary>
        /// <returns>The end offset; 0 unless overridden.</returns>
        public virtual float GetEndVerticalOffset()
        {
            return 0f;
        }

        /// <summary>
        /// Returns an initialized raycast sensor.
        /// </summary>
        /// <returns>
        /// The created <see cref="RayPerceptionSensor"/>, wrapped in a
        /// <see cref="StackingSensor"/> when <see cref="ObservationStacks"/> is greater than 1.
        /// </returns>
        public override ISensor CreateSensor()
        {
            var rayPerceptionInput = GetRayPerceptionInput();
            m_RaySensor = new RayPerceptionSensor(m_SensorName, rayPerceptionInput);

            // Only wrap when stacking is actually requested. Testing "> 1" (rather than
            // "!= 1") keeps this in agreement with GetObservationShape and avoids building
            // a StackingSensor with a zero or negative stack count when the property was
            // set from code to a value outside the inspector range.
            if (ObservationStacks > 1)
            {
                return new StackingSensor(m_RaySensor, ObservationStacks);
            }

            return m_RaySensor;
        }

        /// <summary>
        /// Returns the specific ray angles given the number of rays per direction and the
        /// cone size for the rays.
        /// </summary>
        /// <param name="raysPerDirection">Number of rays to the left and right of center.</param>
        /// <param name="maxRayDegrees">
        /// Cone size for rays. Using 90 degrees will cast rays to the left and right.
        /// Greater than 90 degrees will go backwards.
        /// </param>
        /// <returns>
        /// An array of 2 * raysPerDirection + 1 angles: 90 first, followed by
        /// alternating (90 - k * delta, 90 + k * delta) pairs.
        /// </returns>
        internal static float[] GetRayAngles(int raysPerDirection, float maxRayDegrees)
        {
            // Example:
            // { 90, 90 - delta, 90 + delta, 90 - 2*delta, 90 + 2*delta }
            var anglesOut = new float[2 * raysPerDirection + 1];

            // Floating-point division: with raysPerDirection == 0 this is non-finite,
            // but the loop below never runs in that case so the value is never used.
            var delta = maxRayDegrees / raysPerDirection;

            anglesOut[0] = 90f;
            for (var i = 0; i < raysPerDirection; i++)
            {
                anglesOut[2 * i + 1] = 90 - (i + 1) * delta;
                anglesOut[2 * i + 2] = 90 + (i + 1) * delta;
            }

            return anglesOut;
        }

        /// <summary>
        /// Returns the observation shape for this raycast sensor which depends on the number
        /// of tags for detected objects and the number of rays.
        /// </summary>
        /// <returns>
        /// A single-element array holding (numTags + 2) * numRays, multiplied by the
        /// number of stacks when <see cref="ObservationStacks"/> is greater than 1.
        /// </returns>
        public override int[] GetObservationShape()
        {
            var numRays = 2 * RaysPerDirection + 1;
            // A missing tag list is treated as zero tags.
            var numTags = m_DetectableTags?.Count ?? 0;
            // Each ray contributes one value per tag plus two additional values.
            var obsSize = (numTags + 2) * numRays;
            var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
            return new[] { obsSize * stacks };
        }

        /// <summary>
        /// Get the RayPerceptionInput that is used by the <see cref="RayPerceptionSensor"/>.
        /// </summary>
        /// <returns>A new input built from this component's current settings and transform.</returns>
        public RayPerceptionInput GetRayPerceptionInput()
        {
            var rayAngles = GetRayAngles(RaysPerDirection, MaxRayDegrees);

            var rayPerceptionInput = new RayPerceptionInput();
            rayPerceptionInput.RayLength = RayLength;
            rayPerceptionInput.DetectableTags = DetectableTags;
            rayPerceptionInput.Angles = rayAngles;
            rayPerceptionInput.StartOffset = GetStartVerticalOffset();
            rayPerceptionInput.EndOffset = GetEndVerticalOffset();
            rayPerceptionInput.CastRadius = SphereCastRadius;
            rayPerceptionInput.Transform = transform;
            rayPerceptionInput.CastType = GetCastType();
            rayPerceptionInput.LayerMask = RayLayerMask;

            return rayPerceptionInput;
        }

        /// <summary>
        /// Rebuilds the ray input from the current settings and hands it to the
        /// existing sensor. Does nothing if no sensor has been created yet.
        /// </summary>
        internal void UpdateSensor()
        {
            if (m_RaySensor != null)
            {
                var rayInput = GetRayPerceptionInput();
                m_RaySensor.SetRayPerceptionInput(rayInput);
            }
        }

        /// <summary>
        /// Draws the ray gizmos, using the sensor's cached debug info when available
        /// and otherwise performing the casts directly from the current settings.
        /// </summary>
        void OnDrawGizmosSelected()
        {
            if (m_RaySensor?.debugDisplayInfo?.rayInfos != null)
            {
                // If we have cached debug info from the sensor, draw that.
                // Draw "old" observations in a lighter color.
                // Since the agent may not step every frame, this helps de-emphasize "stale" hit information.
                var alpha = Mathf.Pow(.5f, m_RaySensor.debugDisplayInfo.age);

                foreach (var rayInfo in m_RaySensor.debugDisplayInfo.rayInfos)
                {
                    DrawRaycastGizmos(rayInfo, alpha);
                }
            }
            else
            {
                // No cached results: cast each ray now so the gizmos reflect the current settings.
                var rayInput = GetRayPerceptionInput();
                for (var rayIndex = 0; rayIndex < rayInput.Angles.Count; rayIndex++)
                {
                    DebugDisplayInfo.RayInfo debugRay;
                    RayPerceptionSensor.PerceiveSingleRay(rayInput, rayIndex, out debugRay);
                    DrawRaycastGizmos(debugRay);
                }
            }
        }

        /// <summary>
        /// Draw the debug information from the sensor (if available).
        /// </summary>
        /// <param name="rayInfo">World-space start/end, cast radius and output of one ray.</param>
        /// <param name="alpha">Multiplier applied to the alpha channel of the gizmo color.</param>
        void DrawRaycastGizmos(DebugDisplayInfo.RayInfo rayInfo, float alpha = 1.0f)
        {
            var startPositionWorld = rayInfo.worldStart;
            var endPositionWorld = rayInfo.worldEnd;
            var rayDirection = endPositionWorld - startPositionWorld;
            // Shorten the drawn ray to the hit fraction of the full start-to-end vector.
            rayDirection *= rayInfo.rayOutput.HitFraction;

            // hit fraction ^2 will shift "far" hits closer to the hit color
            var lerpT = rayInfo.rayOutput.HitFraction * rayInfo.rayOutput.HitFraction;
            var color = Color.Lerp(rayHitColor, rayMissColor, lerpT);
            color.a *= alpha;
            Gizmos.color = color;
            Gizmos.DrawRay(startPositionWorld, rayDirection);

            // Draw the hit point as a sphere. If using rays to cast (0 radius), use a small sphere.
            if (rayInfo.rayOutput.HasHit)
            {
                var hitRadius = Mathf.Max(rayInfo.castRadius, .05f);
                Gizmos.DrawWireSphere(startPositionWorld + rayDirection, hitRadius);
            }
        }
    }
}