using System;
using System.Collections.Generic;
using UnityEngine;

namespace MLAgents.Sensors
{
    /// <summary>
    /// Determines which dimensions the sensor will perform the casts in.
    /// </summary>
    public enum RayPerceptionCastType
    {
        /// <summary>
        /// Cast in 2 dimensions, using Physics2D.CircleCast or Physics2D.RayCast.
        /// </summary>
        Cast2D,

        /// <summary>
        /// Cast in 3 dimensions, using Physics.SphereCast or Physics.RayCast.
        /// </summary>
        Cast3D,
    }

    /// <summary>
    /// Contains the elements that define a ray perception sensor.
    /// </summary>
    public struct RayPerceptionInput
    {
        /// <summary>
        /// Length of the rays to cast. This will be scaled up or down based on the scale of the transform.
        /// </summary>
        public float rayLength;

        /// <summary>
        /// List of tags which correspond to object types agent can see.
        /// </summary>
        public IReadOnlyList<string> detectableTags;

        /// <summary>
        /// List of angles (in degrees) used to define the rays.
        /// 90 degrees is considered "forward" relative to the game object.
        /// </summary>
        public IReadOnlyList<float> angles;

        /// <summary>
        /// Starting height offset of ray from center of agent. Only applied for 3D casts.
        /// </summary>
        public float startOffset;

        /// <summary>
        /// Ending height offset of ray from center of agent. Only applied for 3D casts.
        /// </summary>
        public float endOffset;

        /// <summary>
        /// Radius of the sphere to use for spherecasting.
        /// If 0 or less, rays are used instead - this may be faster, especially for complex environments.
        /// </summary>
        public float castRadius;

        /// <summary>
        /// Transform of the GameObject.
        /// </summary>
        public Transform transform;

        /// <summary>
        /// Whether to perform the casts in 2D or 3D.
        /// </summary>
        public RayPerceptionCastType castType;

        /// <summary>
        /// Filtering options for the casts.
        /// </summary>
        public int layerMask;

        /// <summary>
        /// Returns the expected number of floats in the output.
        /// </summary>
        /// <returns>(number of detectable tags + 2) * number of rays.</returns>
        public int OutputSize()
        {
            return (detectableTags.Count + 2) * angles.Count;
        }

        /// <summary>
        /// Get the cast start and end points for the given ray index.
        /// </summary>
        /// <param name="rayIndex">Index into <see cref="angles"/> of the ray to compute.</param>
        /// <returns>A tuple of the start and end positions in world space.</returns>
        public (Vector3 StartPositionWorld, Vector3 EndPositionWorld) RayExtents(int rayIndex)
        {
            var angle = angles[rayIndex];
            Vector3 startPositionLocal, endPositionLocal;
            if (castType == RayPerceptionCastType.Cast3D)
            {
                startPositionLocal = new Vector3(0, startOffset, 0);
                endPositionLocal = PolarToCartesian3D(rayLength, angle);
                endPositionLocal.y += endOffset;
            }
            else
            {
                // Vector2s here get converted to Vector3s (and back to Vector2s for casting)
                startPositionLocal = new Vector2();
                endPositionLocal = PolarToCartesian2D(rayLength, angle);
            }

            // TransformPoint applies the transform's position, rotation and scale.
            var startPositionWorld = transform.TransformPoint(startPositionLocal);
            var endPositionWorld = transform.TransformPoint(endPositionLocal);

            return (StartPositionWorld: startPositionWorld, EndPositionWorld: endPositionWorld);
        }

        /// <summary>
        /// Converts a polar coordinate to a cartesian coordinate in the XZ plane.
        /// </summary>
        /// <param name="radius">Distance from the origin.</param>
        /// <param name="angleDegrees">Angle in degrees, measured from the +X axis towards +Z.</param>
        static internal Vector3 PolarToCartesian3D(float radius, float angleDegrees)
        {
            var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
            var z = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
            return new Vector3(x, 0f, z);
        }

        /// <summary>
        /// Converts a polar coordinate to a cartesian coordinate in the XY plane.
        /// </summary>
        /// <param name="radius">Distance from the origin.</param>
        /// <param name="angleDegrees">Angle in degrees, measured from the +X axis towards +Y.</param>
        static internal Vector2 PolarToCartesian2D(float radius, float angleDegrees)
        {
            var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
            var y = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
            return new Vector2(x, y);
        }
    }

    /// <summary>
    /// Contains the data generated/produced from a ray perception sensor.
    /// </summary>
    public class RayPerceptionOutput
    {
        /// <summary>
        /// Contains the data generated from a single ray of a ray perception sensor.
        /// </summary>
        public struct RayOutput
        {
            /// <summary>
            /// Whether or not the ray hit anything.
            /// </summary>
            public bool hasHit;

            /// <summary>
            /// Whether or not the ray hit an object whose tag is in the input's detectableTags list.
            /// </summary>
            public bool hitTaggedObject;

            /// <summary>
            /// The index of the hit object's tag in the detectableTags list, or -1 if there was no hit, or the
            /// hit object has a different tag.
            /// </summary>
            public int hitTagIndex;

            /// <summary>
            /// Normalized distance to the hit object, or 1.0 if nothing was hit.
            /// </summary>
            public float hitFraction;

            /// <summary>
            /// Writes the ray output information to a subset of the float array. Each element in the rayAngles array
            /// determines a sublist of data to the observation. The sublist contains the observation data for a single cast.
            /// The list is composed of the following:
            /// 1. A one-hot encoding for detectable tags. For example, if detectableTags.Length = n, the
            ///    first n elements of the sublist will be a one-hot encoding of the detectableTag that was hit, or
            ///    all zeroes otherwise.
            /// 2. The 'numDetectableTags' element of the sublist will be 1 if the ray missed everything, or 0 if it hit
            ///    something (detectable or not).
            /// 3. The 'numDetectableTags+1' element of the sublist will contain the normalized distance to the object
            ///    hit, or 1.0 if nothing was hit.
            /// Only the "hot" tag slot is written, so the caller is expected to zero the buffer beforehand.
            /// </summary>
            /// <param name="numDetectableTags">Number of detectable tags in the input.</param>
            /// <param name="rayIndex">Index of this ray; selects which sublist of the buffer is written.</param>
            /// <param name="buffer">Output buffer. The size must be equal to (numDetectableTags+2) * rayOutputs.Length</param>
            public void ToFloatArray(int numDetectableTags, int rayIndex, float[] buffer)
            {
                var bufferOffset = (numDetectableTags + 2) * rayIndex;
                if (hitTaggedObject)
                {
                    buffer[bufferOffset + hitTagIndex] = 1f;
                }
                buffer[bufferOffset + numDetectableTags] = hasHit ? 0f : 1f;
                buffer[bufferOffset + numDetectableTags + 1] = hitFraction;
            }
        }

        /// <summary>
        /// RayOutput for each ray that was cast.
        /// </summary>
        public RayOutput[] rayOutputs;
    }

    /// <summary>
    /// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
    /// </summary>
    internal class DebugDisplayInfo
    {
        /// <summary>
        /// World-space extents, cast radius and result of a single cast.
        /// </summary>
        public struct RayInfo
        {
            public Vector3 worldStart;
            public Vector3 worldEnd;
            public float castRadius;
            public RayPerceptionOutput.RayOutput rayOutput;
        }

        /// <summary>
        /// Records the current frame so that <see cref="age"/> is measured from now.
        /// </summary>
        public void Reset()
        {
            m_Frame = Time.frameCount;
        }

        /// <summary>
        /// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
        /// </summary>
        public int age
        {
            get { return Time.frameCount - m_Frame; }
        }

        public RayInfo[] rayInfos;

        int m_Frame;
    }

    /// <summary>
    /// A sensor implementation that supports ray cast-based observations.
    /// </summary>
    public class RayPerceptionSensor : ISensor
    {
        float[] m_Observations;
        int[] m_Shape;
        string m_Name;

        RayPerceptionInput m_RayPerceptionInput;

        DebugDisplayInfo m_DebugDisplayInfo;

        internal DebugDisplayInfo debugDisplayInfo
        {
            get { return m_DebugDisplayInfo; }
        }

        /// <summary>
        /// Creates the RayPerceptionSensor.
        /// </summary>
        /// <param name="name">The name of the sensor.</param>
        /// <param name="rayInput">The inputs for the sensor.</param>
        public RayPerceptionSensor(string name, RayPerceptionInput rayInput)
        {
            var numObservations = rayInput.OutputSize();
            m_Name = name;
            m_RayPerceptionInput = rayInput;
            SetNumObservations(numObservations);

            // Debug info is only collected in the editor.
            if (Application.isEditor)
            {
                m_DebugDisplayInfo = new DebugDisplayInfo();
            }
        }

        /// <summary>
        /// (Re)allocates the observation buffer and observation shape for the given number of floats.
        /// </summary>
        void SetNumObservations(int numObservations)
        {
            m_Shape = new[] { numObservations };
            m_Observations = new float[numObservations];
        }

        /// <summary>
        /// Replaces the inputs used for subsequent casts.
        /// </summary>
        /// <remarks>
        /// The observation buffer is sized from the number of rays and detectable tags. If the new input
        /// changes that size, the buffer and shape are reallocated so that <see cref="Write"/> never writes
        /// past the end of the buffer, and a warning is logged because consumers that already read the old
        /// observation shape may not expect the change.
        /// </remarks>
        /// <param name="input">The new ray perception inputs.</param>
        internal void SetRayPerceptionInput(RayPerceptionInput input)
        {
            var numObservations = input.OutputSize();
            if (numObservations != m_Observations.Length)
            {
                Debug.LogWarning(
                    "Changing the number of rays or detectable tags at runtime is not supported " +
                    $"(observation size changed from {m_Observations.Length} to {numObservations}); " +
                    "consumers of the observation shape may break."
                );
                SetNumObservations(numObservations);
            }
            m_RayPerceptionInput = input;
        }

        /// <summary>
        /// Computes the ray perception observations and saves them to the provided
        /// <see cref="WriteAdapter"/>.
        /// </summary>
        /// <param name="adapter">Where the ray perception observations are written to.</param>
        /// <returns>The number of floats added to the adapter (the length of the observation buffer).</returns>
        public int Write(WriteAdapter adapter)
        {
            using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
            {
                // ToFloatArray only sets the "hot" tag slot, so the buffer must start zeroed.
                Array.Clear(m_Observations, 0, m_Observations.Length);

                var numRays = m_RayPerceptionInput.angles.Count;
                var numDetectableTags = m_RayPerceptionInput.detectableTags.Count;

                if (m_DebugDisplayInfo != null)
                {
                    // Reset the age information, and resize the buffer if needed.
                    m_DebugDisplayInfo.Reset();
                    if (m_DebugDisplayInfo.rayInfos == null || m_DebugDisplayInfo.rayInfos.Length != numRays)
                    {
                        m_DebugDisplayInfo.rayInfos = new DebugDisplayInfo.RayInfo[numRays];
                    }
                }

                // For each ray, do the casting, and write the information to the observation buffer
                for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
                {
                    var rayOutput = PerceiveSingleRay(m_RayPerceptionInput, rayIndex, out var debugRay);

                    if (m_DebugDisplayInfo != null)
                    {
                        m_DebugDisplayInfo.rayInfos[rayIndex] = debugRay;
                    }

                    rayOutput.ToFloatArray(numDetectableTags, rayIndex, m_Observations);
                }

                // Finally, add the observations to the WriteAdapter
                adapter.AddRange(m_Observations);
            }
            return m_Observations.Length;
        }

        /// <inheritdoc/>
        public void Update()
        {
        }

        /// <inheritdoc/>
        public int[] GetObservationShape()
        {
            return m_Shape;
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_Name;
        }

        /// <inheritdoc/>
        public virtual byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <inheritdoc/>
        public virtual SensorCompressionType GetCompressionType()
        {
            return SensorCompressionType.None;
        }

        /// <summary>
        /// Evaluates the raycasts to be used as part of an observation of an agent.
        /// </summary>
        /// <param name="input">Input defining the rays that will be cast.</param>
        /// <returns>Output struct containing the raycast results.</returns>
        public static RayPerceptionOutput Perceive(RayPerceptionInput input)
        {
            var output = new RayPerceptionOutput();
            output.rayOutputs = new RayPerceptionOutput.RayOutput[input.angles.Count];

            for (var rayIndex = 0; rayIndex < input.angles.Count; rayIndex++)
            {
                // Debug info is not needed for the static API.
                output.rayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex, out _);
            }

            return output;
        }

        /// <summary>
        /// Evaluate the raycast results of a single ray from the RayPerceptionInput.
        /// </summary>
        /// <param name="input">Input defining the rays.</param>
        /// <param name="rayIndex">Index of the ray (into <see cref="RayPerceptionInput.angles"/>) to cast.</param>
        /// <param name="debugRayOut">Receives the world-space extents, scaled cast radius and result of the cast.</param>
        /// <returns>The result of the cast for this ray.</returns>
        internal static RayPerceptionOutput.RayOutput PerceiveSingleRay(
            RayPerceptionInput input,
            int rayIndex,
            out DebugDisplayInfo.RayInfo debugRayOut
        )
        {
            var unscaledRayLength = input.rayLength;
            var unscaledCastRadius = input.castRadius;

            var extents = input.RayExtents(rayIndex);
            var startPositionWorld = extents.StartPositionWorld;
            var endPositionWorld = extents.EndPositionWorld;

            var rayDirection = endPositionWorld - startPositionWorld;
            // If there is non-unity scale, |rayDirection| will be different from rayLength.
            // We want to use this transformed ray length for determining cast length, hit fraction etc.
            // We also use it to scale up or down the sphere or circle radii.
            var scaledRayLength = rayDirection.magnitude;
            // Avoid 0/0 if unscaledRayLength is 0
            var scaledCastRadius = unscaledRayLength > 0 ?
                unscaledCastRadius * scaledRayLength / unscaledRayLength :
                unscaledCastRadius;

            // Do the cast and assign the hit information for each detectable tag.
            bool castHit;
            float hitFraction;
            GameObject hitObject;

            if (input.castType == RayPerceptionCastType.Cast3D)
            {
                RaycastHit rayHit;
                if (scaledCastRadius > 0f)
                {
                    castHit = Physics.SphereCast(startPositionWorld, scaledCastRadius, rayDirection, out rayHit,
                        scaledRayLength, input.layerMask);
                }
                else
                {
                    castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit,
                        scaledRayLength, input.layerMask);
                }

                // If scaledRayLength is 0, we still could have a hit with sphere casts (maybe?).
                // To avoid 0/0, set the fraction to 0.
                hitFraction = castHit ? (scaledRayLength > 0 ? rayHit.distance / scaledRayLength : 0.0f) : 1.0f;
                hitObject = castHit ? rayHit.collider.gameObject : null;
            }
            else
            {
                RaycastHit2D rayHit;
                if (scaledCastRadius > 0f)
                {
                    rayHit = Physics2D.CircleCast(startPositionWorld, scaledCastRadius, rayDirection,
                        scaledRayLength, input.layerMask);
                }
                else
                {
                    rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, scaledRayLength, input.layerMask);
                }

                // RaycastHit2D converts implicitly to bool (true when something was hit).
                castHit = rayHit;
                hitFraction = castHit ? rayHit.fraction : 1.0f;
                hitObject = castHit ? rayHit.collider.gameObject : null;
            }

            var rayOutput = new RayPerceptionOutput.RayOutput
            {
                hasHit = castHit,
                hitFraction = hitFraction,
                hitTaggedObject = false,
                hitTagIndex = -1
            };

            if (castHit)
            {
                // Find the index of the tag of the object that was hit.
                for (var i = 0; i < input.detectableTags.Count; i++)
                {
                    if (hitObject.CompareTag(input.detectableTags[i]))
                    {
                        rayOutput.hitTaggedObject = true;
                        rayOutput.hitTagIndex = i;
                        break;
                    }
                }
            }

            debugRayOut.worldStart = startPositionWorld;
            debugRayOut.worldEnd = endPositionWorld;
            debugRayOut.rayOutput = rayOutput;
            debugRayOut.castRadius = scaledCastRadius;

            return rayOutput;
        }
    }
}