namespace MLAgents
{
public class RayPerceptionSensor : ISensor
/// <summary>
/// Determines which dimensions the sensor will perform the casts in.
/// </summary>
public enum RayPerceptionCastType
{
/// Cast in 2 dimensions, using Physics2D.CircleCast or Physics2D.RayCast.
Cast2D ,
/// Cast in 3 dimensions, using Physics.SphereCast or Physics.RayCast.
Cast3D ,
}
public struct RayPerceptionInput
public enum CastType
{
Cast2D ,
Cast3D ,
}
/// <summary>
/// Length of the rays to cast. This will be scaled up or down based on the scale of the transform.
/// </summary>
public float rayLength ;
/// <summary>
/// List of tags which correspond to object types agent can see.
/// </summary>
public IReadOnlyList < string > detectableTags ;
/// <summary>
/// List of angles (in degrees) used to define the rays.
/// 90 degrees is considered "forward" relative to the game object.
/// </summary>
public IReadOnlyList < float > angles ;
/// <summary>
/// Starting height offset of ray from center of agent
/// </summary>
public float startOffset ;
float [ ] m_Observations ;
int [ ] m_Shape ;
string m_Name ;
/// <summary>
/// Ending height offset of ray from center of agent.
/// </summary>
public float endOffset ;
float m_RayDistance ;
List < string > m_DetectableObjects ;
float [ ] m_Angles ;
/// <summary>
/// Radius of the sphere to use for spherecasting.
/// If 0 or less, rays are used instead - this may be faster, especially for complex environments.
/// </summary>
public float castRadius ;
float m_StartOffset ;
float m_EndOffset ;
float m_CastRadius ;
CastType m_CastType ;
Transform m_Transform ;
int m_LayerMask ;
/// <summary>
/// Transform of the GameObject.
/// </summary>
public Transform transform ;
/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
/// Whether to perform the casts in 2D or 3D.
public class DebugDisplayInfo
public RayPerceptionCastType castType ;
/// <summary>
/// Filtering options for the casts.
/// </summary>
public int layerMask ;
/// <summary>
/// Returns the expected number of floats in the output.
/// </summary>
/// <returns></returns>
public int OutputSize ( )
public struct RayInfo
return ( detectableTags . Count + 2 ) * angles . Count ;
}
/// <summary>
/// Get the cast start and end points for the given ray index/
/// </summary>
/// <param name="rayIndex"></param>
/// <returns>A tuple of the start and end positions in world space.</returns>
public ( Vector3 StartPositionWorld , Vector3 EndPositionWorld ) RayExtents ( int rayIndex )
{
var angle = angles [ rayIndex ] ;
Vector3 startPositionLocal , endPositionLocal ;
if ( castType = = RayPerceptionCastType . Cast3D )
public Vector3 localStart ;
public Vector3 localEnd ;
public Vector3 worldStart ;
public Vector3 worldEnd ;
public bool castHit ;
public float hitFraction ;
public float castRadius ;
startPositionLocal = new Vector3 ( 0 , startOffset , 0 ) ;
endPositionLocal = PolarToCartesian3D ( rayLength , angle ) ;
endPositionLocal . y + = endOffset ;
public void Reset ( )
else
m_Frame = Time . frameCount ;
// Vector2s here get converted to Vector3s (and back to Vector2s for casting)
startPositionLocal = new Vector2 ( ) ;
endPositionLocal = PolarToCartesian2D ( rayLength , angle ) ;
var startPositionWorld = transform . TransformPoint ( startPositionLocal ) ;
var endPositionWorld = transform . TransformPoint ( endPositionLocal ) ;
return ( StartPositionWorld : startPositionWorld , EndPositionWorld : endPositionWorld ) ;
}
/// <summary>
/// Converts polar coordinate to cartesian coordinate.
/// </summary>
static internal Vector3 PolarToCartesian3D ( float radius , float angleDegrees )
{
var x = radius * Mathf . Cos ( Mathf . Deg2Rad * angleDegrees ) ;
var z = radius * Mathf . Sin ( Mathf . Deg2Rad * angleDegrees ) ;
return new Vector3 ( x , 0f , z ) ;
}
/// <summary>
/// Converts polar coordinate to cartesian coordinate.
/// </summary>
static internal Vector2 PolarToCartesian2D ( float radius , float angleDegrees )
{
var x = radius * Mathf . Cos ( Mathf . Deg2Rad * angleDegrees ) ;
var y = radius * Mathf . Sin ( Mathf . Deg2Rad * angleDegrees ) ;
return new Vector2 ( x , y ) ;
}
}
public class RayPerceptionOutput
{
public struct RayOutput
{
/// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
/// Whether or not the ray hit anything.
/// </summary>
public bool hasHit ;
/// <summary>
/// Whether or not the ray hit an object whose tag is in the input's detectableTags list.
/// </summary>
public bool hitTaggedObject ;
/// <summary>
/// The index of the hit object's tag in the detectableTags list, or -1 if there was no hit, or the
/// hit object has a different tag.
/// </summary>
public int hitTagIndex ;
/// <summary>
/// Normalized distance to the hit object.
/// </summary>
public float hitFraction ;
/// <summary>
/// Writes the ray output information to a subset of the float array. Each element in the rayAngles array
/// determines a sublist of data to the observation. The sublist contains the observation data for a single cast.
/// The list is composed of the following:
/// 1. A one-hot encoding for detectable tags. For example, if detectableTags.Length = n, the
/// first n elements of the sublist will be a one-hot encoding of the detectableTag that was hit, or
/// all zeroes otherwise.
/// 2. The 'numDetectableTags' element of the sublist will be 1 if the ray missed everything, or 0 if it hit
/// something (detectable or not).
/// 3. The 'numDetectableTags+1' element of the sublist will contain the normalized distance to the object
/// hit, or 1.0 if nothing was hit.
public int age
/// <param name="numDetectableTags"></param>
/// <param name="rayIndex"></param>
/// <param name="buffer">Output buffer. The size must be equal to (numDetectableTags+2) * rayOutputs.Length</param>
public void ToFloatArray ( int numDetectableTags , int rayIndex , float [ ] buffer )
get { return Time . frameCount - m_Frame ; }
var bufferOffset = ( numDetectableTags + 2 ) * rayIndex ;
if ( hitTaggedObject )
{
buffer [ bufferOffset + hitTagIndex ] = 1f ;
}
buffer [ bufferOffset + numDetectableTags ] = hasHit ? 0f : 1f ;
buffer [ bufferOffset + numDetectableTags + 1 ] = hitFraction ;
}
/// <summary>
/// RayOutput for each ray that was cast.
/// </summary>
public RayOutput [ ] rayOutputs ;
}
public RayInfo [ ] rayInfos ;
/// <summary>
/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
/// </summary>
internal class DebugDisplayInfo
{
public struct RayInfo
{
public Vector3 worldStart ;
public Vector3 worldEnd ;
public float castRadius ;
public RayPerceptionOutput . RayOutput rayOutput ;
}
public void Reset ( )
{
m_Frame = Time . frameCount ;
}
int m_Frame ;
/// <summary>
/// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
/// </summary>
public int age
{
get { return Time . frameCount - m_Frame ; }
public RayInfo [ ] rayInfos ;
int m_Frame ;
}
public class RayPerceptionSensor : ISensor
{
float [ ] m_Observations ;
int [ ] m_Shape ;
string m_Name ;
RayPerceptionInput m_RayPerceptionInput ;
public DebugDisplayInfo debugDisplayInfo
internal DebugDisplayInfo debugDisplayInfo
public RayPerceptionSensor ( string name , float rayDistance , List < string > detectableObjects , float [ ] angles ,
Transform transform , float startOffset , float endOffset , float castRadius , CastType castType ,
int rayLayerMask )
public RayPerceptionSensor ( string name , RayPerceptionInput rayInput )
var numObservations = ( detectableObjects . Count + 2 ) * angles . Length ;
var numObservations = rayInput . OutputSize ( ) ;
m_RayPerceptionInput = rayInput ;
m_RayDistance = rayDistance ;
m_DetectableObjects = detectableObjects ;
// TODO - preprocess angles, save ray directions instead?
m_Angles = angles ;
m_Transform = transform ;
m_StartOffset = startOffset ;
m_EndOffset = endOffset ;
m_CastRadius = castRadius ;
m_CastType = castType ;
m_LayerMask = rayLayerMask ;
if ( Application . isEditor )
{
{
using ( TimerStack . Instance . Scoped ( "RayPerceptionSensor.Perceive" ) )
{
PerceiveStatic (
m_RayDistance , m_Angles , m_DetectableObjects , m_StartOffset , m_EndOffset ,
m_CastRadius , m_Transform , m_CastType , m_Observations , m_LayerMask ,
m_DebugDisplayInfo
) ;
Array . Clear ( m_Observations , 0 , m_Observations . Length ) ;
var numRays = m_RayPerceptionInput . angles . Count ;
var numDetectableTags = m_RayPerceptionInput . detectableTags . Count ;
if ( m_DebugDisplayInfo ! = null )
{
// Reset the age information, and resize the buffer if needed.
m_DebugDisplayInfo . Reset ( ) ;
if ( m_DebugDisplayInfo . rayInfos = = null | | m_DebugDisplayInfo . rayInfos . Length ! = numRays )
{
m_DebugDisplayInfo . rayInfos = new DebugDisplayInfo . RayInfo [ numRays ] ;
}
}
// For each ray, do the casting, and write the information to the observation buffer
for ( var rayIndex = 0 ; rayIndex < numRays ; rayIndex + + )
{
DebugDisplayInfo . RayInfo debugRay ;
var rayOutput = PerceiveSingleRay ( m_RayPerceptionInput , rayIndex , out debugRay ) ;
if ( m_DebugDisplayInfo ! = null )
{
m_DebugDisplayInfo . rayInfos [ rayIndex ] = debugRay ;
}
rayOutput . ToFloatArray ( numDetectableTags , rayIndex , m_Observations ) ;
}
// Finally, add the observations to the WriteAdapter
adapter . AddRange ( m_Observations ) ;
}
return m_Observations . Length ;
}
/// <summary>
/// Evaluates a perception vector to be used as part of an observation of an agent.
/// Each element in the rayAngles array determines a sublist of data to the observation.
/// The sublist contains the observation data for a single cast. The list is composed of the following:
/// 1. A one-hot encoding for detectable objects. For example, if detectableObjects.Length = n, the
/// first n elements of the sublist will be a one-hot encoding of the detectableObject that was hit, or
/// all zeroes otherwise.
/// 2. The 'length' element of the sublist will be 1 if the ray missed everything, or 0 if it hit
/// something (detectable or not).
/// 3. The 'length+1' element of the sublist will contain the normalised distance to the object hit, or 1 if
/// nothing was hit.
///
/// Evaluates the raycasts to be used as part of an observation of an agent.
/// <param name="unscaledRayLength"></param>
/// <param name="rayAngles">List of angles (in degrees) used to define the rays. 90 degrees is considered
/// "forward" relative to the game object</param>
/// <param name="detectableObjects">List of tags which correspond to object types agent can see</param>
/// <param name="startOffset">Starting height offset of ray from center of agent.</param>
/// <param name="endOffset">Ending height offset of ray from center of agent.</param>
/// <param name="unscaledCastRadius">Radius of the sphere to use for spherecasting. If 0 or less, rays are used
/// instead - this may be faster, especially for complex environments.</param>
/// <param name="transform">Transform of the GameObject</param>
/// <param name="castType">Whether to perform the casts in 2D or 3D.</param>
/// <param name="perceptionBuffer">Output array of floats. Must be (num rays) * (num tags + 2) in size.</param>
/// <param name="layerMask">Filtering options for the casts</param>
/// <param name="debugInfo">Optional debug information output, only used by RayPerceptionSensor.</param>
/// <param name="input">Input defining the rays that will be cast.</param>
/// <param name="output">Output class that will be written to with raycast results.</param>
public static void PerceiveStatic ( float unscaledRayLength ,
IReadOnlyList < float > rayAngles , IReadOnlyList < string > detectableObjects ,
float startOffset , float endOffset , float unscaledCastRadius ,
Transform transform , CastType castType , float [ ] perceptionBuffer ,
int layerMask = Physics . DefaultRaycastLayers ,
DebugDisplayInfo debugInfo = null )
public static RayPerceptionOutput PerceiveStatic ( RayPerceptionInput input )
Array . Clear ( perceptionBuffer , 0 , perceptionBuffer . Length ) ;
if ( debugInfo ! = null )
RayPerceptionOutput output = new RayPerceptionOutput ( ) ;
output . rayOutputs = new RayPerceptionOutput . RayOutput [ input . angles . Count ] ;
for ( var rayIndex = 0 ; rayIndex < input . angles . Count ; rayIndex + + )
debugInfo . Reset ( ) ;
if ( debugInfo . rayInfos = = null | | debugInfo . rayInfos . Length ! = rayAngles . Count )
{
debugInfo . rayInfos = new DebugDisplayInfo . RayInfo [ rayAngles . Count ] ;
}
DebugDisplayInfo . RayInfo debugRay ;
output . rayOutputs [ rayIndex ] = PerceiveSingleRay ( input , rayIndex , out debugRay ) ;
// For each ray sublist stores categorical information on detected object
// along with object distance.
int bufferOffset = 0 ;
for ( var rayIndex = 0 ; rayIndex < rayAngles . Count ; rayIndex + + )
{
var angle = rayAngles [ rayIndex ] ;
Vector3 startPositionLocal , endPositionLocal ;
if ( castType = = CastType . Cast3D )
{
startPositionLocal = new Vector3 ( 0 , startOffset , 0 ) ;
endPositionLocal = PolarToCartesian3D ( unscaledRayLength , angle ) ;
endPositionLocal . y + = endOffset ;
}
else
{
// Vector2s here get converted to Vector3s (and back to Vector2s for casting)
startPositionLocal = new Vector2 ( ) ;
endPositionLocal = PolarToCartesian2D ( unscaledRayLength , angle ) ;
}
return output ;
}
var startPositionWorld = transform . TransformPoint ( startPositionLocal ) ;
var endPositionWorld = transform . TransformPoint ( endPositionLocal ) ;
/// <summary>
/// Evaluate the raycast results of a single ray from the RayPerceptionInput.
/// </summary>
/// <param name="input"></param>
/// <param name="rayIndex"></param>
/// <param name="debugRayOut"></param>
/// <returns></returns>
static RayPerceptionOutput . RayOutput PerceiveSingleRay (
RayPerceptionInput input ,
int rayIndex ,
out DebugDisplayInfo . RayInfo debugRayOut
)
{
var unscaledRayLength = input . rayLength ;
var unscaledCastRadius = input . castRadius ;
var rayDirection = endPositionWorld - startPositionWorld ;
// If there is non-unity scale, |rayDirection| will be different from rayLength.
// We want to use this transformed ray length for determining cast length, hit fraction etc.
// We also it to scale up or down the sphere or circle radii
var scaledRayLength = rayDirection . magnitude ;
// Avoid 0/0 if unscaledRayLength is 0
var scaledCastRadius = unscaledRayLength > 0 ? unscaledCastRadius * scaledRayLength / unscaledRayLength : unscaledCastRadius ;
var extents = input . RayExtents ( rayIndex ) ;
var startPositionWorld = extents . StartPositionWorld ;
var endPositionWorld = extents . EndPositionWorld ;
// Do the cast and assign the hit information for each detectable object.
// sublist[0 ] <- did hit detectableObjects[0]
// ...
// sublist[numObjects-1] <- did hit detectableObjects[numObjects-1]
// sublist[numObjects ] <- 1 if missed else 0
// sublist[numObjects+1] <- hit fraction (or 1 if no hit)
var rayDirection = endPositionWorld - startPositionWorld ;
// If there is non-unity scale, |rayDirection| will be different from rayLength.
// We want to use this transformed ray length for determining cast length, hit fraction etc.
// We also it to scale up or down the sphere or circle radii
var scaledRayLength = rayDirection . magnitude ;
// Avoid 0/0 if unscaledRayLength is 0
var scaledCastRadius = unscaledRayLength > 0 ?
unscaledCastRadius * scaledRayLength / unscaledRayLength :
unscaledCastRadius ;
bool castHit ;
float hitFraction ;
GameObject hitObject ;
// Do the cast and assign the hit information for each detectable tag.
bool castHit ;
float hitFraction ;
GameObject hitObject ;
if ( castType = = CastType . Cast3D )
if ( input . castType = = RayPerceptionCastType . Cast3D )
{
RaycastHit rayHit ;
if ( scaledCastRadius > 0f )
RaycastHit rayHit ;
if ( scaledCastRadius > 0f )
{
castHit = Physics . SphereCast ( startPositionWorld , scaledCastRadius , rayDirection , out rayHit ,
scaledRayLength , layerMask ) ;
}
else
{
castHit = Physics . Raycast ( startPositionWorld , rayDirection , out rayHit ,
scaledRayLength , layerMask ) ;
}
// If scaledRayLength is 0, we still could have a hit with sphere casts (maybe?).
// To avoid 0/0, set the fraction to 0.
hitFraction = castHit ? ( scaledRayLength > 0 ? rayHit . distance / scaledRayLength : 0.0f ) : 1.0f ;
hitObject = castHit ? rayHit . collider . gameObject : null ;
castHit = Physics . SphereCast ( startPositionWorld , scaledCastRadius , rayDirection , out rayHit ,
scaledRayLength , input . layerMask ) ;
RaycastHit2D rayHit ;
if ( scaledCastRadius > 0f )
{
rayHit = Physics2D . CircleCast ( startPositionWorld , scaledCastRadius , rayDirection ,
scaledRayLength , layerMask ) ;
}
else
{
rayHit = Physics2D . Raycast ( startPositionWorld , rayDirection , scaledRayLength , layerMask ) ;
}
castHit = rayHit ;
hitFraction = castHit ? rayHit . fraction : 1.0f ;
hitObject = castHit ? rayHit . collider . gameObject : null ;
castHit = Physics . Raycast ( startPositionWorld , rayDirection , out rayHit ,
scaledRayLength , input . layerMask ) ;
if ( debugInfo ! = null )
// If scaledRayLength is 0, we still could have a hit with sphere casts (maybe?).
// To avoid 0/0, set the fraction to 0.
hitFraction = castHit ? ( scaledRayLength > 0 ? rayHit . distance / scaledRayLength : 0.0f ) : 1.0f ;
hitObject = castHit ? rayHit . collider . gameObject : null ;
}
else
{
RaycastHit2D rayHit ;
if ( scaledCastRadius > 0f )
debugInfo . rayInfos [ rayIndex ] . localStart = startPositionLocal ;
debugInfo . rayInfos [ rayIndex ] . localEnd = endPositionLocal ;
debugInfo . rayInfos [ rayIndex ] . worldStart = startPositionWorld ;
debugInfo . rayInfos [ rayIndex ] . worldEnd = endPositionWorld ;
debugInfo . rayInfos [ rayIndex ] . castHit = castHit ;
debugInfo . rayInfos [ rayIndex ] . hitFraction = hitFraction ;
debugInfo . rayInfos [ rayIndex ] . castRadius = scaledCastRadius ;
rayHit = Physics2D . CircleCast ( startPositionWorld , scaledCastRadius , rayDirection ,
scaledRayLength , input . layerMask ) ;
else if ( Application . isEditor )
else
// Legacy drawing
Debug . DrawRay ( startPositionWorld , rayDirection , Color . black , 0.01f , true ) ;
rayHit = Physics2D . Raycast ( startPositionWorld , rayDirection , scaledRayLength , input . layerMask ) ;
if ( castHit )
castHit = rayHit ;
hitFraction = castHit ? rayHit . fraction : 1.0f ;
hitObject = castHit ? rayHit . collider . gameObject : null ;
}
var rayOutput = new RayPerceptionOutput . RayOutput
{
hasHit = castHit ,
hitFraction = hitFraction ,
hitTaggedObject = false ,
hitTagIndex = - 1
} ;
if ( castHit )
{
// Find the index of the tag of the object that was hit.
for ( var i = 0 ; i < input . detectableTags . Count ; i + + )
bool hitTaggedObject = false ;
for ( var i = 0 ; i < detectableObjects . Count ; i + + )
if ( hitObject . CompareTag ( input . detectableTags [ i ] ) )
if ( hitObject . CompareTag ( detectableObjects [ i ] ) )
{
perceptionBuffer [ bufferOffset + i ] = 1 ;
perceptionBuffer [ bufferOffset + detectableObjects . Count + 1 ] = hitFraction ;
hitTaggedObject = true ;
break ;
}
}
if ( ! hitTaggedObject )
{
// Something was hit but not on the list. Still set the hit fraction.
perceptionBuffer [ bufferOffset + detectableObjects . Count + 1 ] = hitFraction ;
rayOutput . hitTaggedObject = true ;
rayOutput . hitTagIndex = i ;
break ;
else
{
perceptionBuffer [ bufferOffset + detectableObjects . Count ] = 1f ;
// Nothing was hit, so there's full clearance in front of the agent.
perceptionBuffer [ bufferOffset + detectableObjects . Count + 1 ] = 1.0f ;
}
}
bufferOffset + = detectableObjects . Count + 2 ;
}
}
debugRayOut . worldStart = startPositionWorld ;
debugRayOut . worldEnd = endPositionWorld ;
debugRayOut . rayOutput = rayOutput ;
debugRayOut . castRadius = scaledCastRadius ;
/// <summary>
/// Converts polar coordinate to cartesian coordinate.
/// </summary>
static Vector3 PolarToCartesian3D ( float radius , float angleDegrees )
{
var x = radius * Mathf . Cos ( Mathf . Deg2Rad * angleDegrees ) ;
var z = radius * Mathf . Sin ( Mathf . Deg2Rad * angleDegrees ) ;
return new Vector3 ( x , 0f , z ) ;
}
return rayOutput ;
/// <summary>
/// Converts polar coordinate to cartesian coordinate.
/// </summary>
static Vector2 PolarToCartesian2D ( float radius , float angleDegrees )
{
var x = radius * Mathf . Cos ( Mathf . Deg2Rad * angleDegrees ) ;
var y = radius * Mathf . Sin ( Mathf . Deg2Rad * angleDegrees ) ;
return new Vector2 ( x , y ) ;
}
}
}