using UnityEngine ;
using UnityEngine.Assertions ;
using Unity.MLAgents.Sensors ;
using UnityEngine.Profiling ;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Root object of the agent. Colliders whose GameObject is this object are compared against it
/// while parsing a cell. When left null it is defaulted to the current GameObject during
/// initialization.
/// </summary>
[Tooltip("The reference of the root of the agent. This is used to disambiguate objects with the same tag as the agent. Defaults to current GameObject")]
public GameObject rootReference;

/// <summary>
/// Upper bound for the length of the collider buffer; the buffer is never grown past this value.
/// </summary>
[Header("Collider Buffer Properties")]
[Tooltip("The absolute max size of the Collider buffer used in the non-allocating Physics calls. In other words" +
    " the Collider buffer will never grow beyond this number even if there are more Colliders in the Grid Cell.")]
public int MaxColliderBufferSize = 500;

/// <summary>
/// Length used when the collider buffer is first allocated (capped by
/// <see cref="MaxColliderBufferSize"/>). It is overwritten with the new length whenever the
/// buffer is grown, so later re-allocations start from the grown size.
/// </summary>
[Tooltip(
    "The Estimated Max Number of Colliders to expect per cell. This number is used to " +
    "pre-allocate an array of Colliders in order to take advantage of the OverlapBoxNonAlloc " +
    "Physics API. If the number of colliders found is >= InitialColliderBufferSize the array " +
    "will be resized to double its current size. The hard coded absolute size is 500.")]
public int InitialColliderBufferSize = 4;

// Reusable result array handed to Physics.OverlapBoxNonAlloc for every cell query.
Collider[] m_ColliderBuffer;

// Scratch array of per-channel values referenced by GetObjectData.
float[] m_ChannelBuffer;
//
// Hidden Parameters
//
/// <summary>
/// Radius of grid, used for normalizing the distance.
/// </summary>
protected float SphereRadius;

/// <summary>
/// Reciprocal of the grid radius (1 / radius), so a distance can be normalized with a
/// multiplication instead of a division.
/// </summary>
protected float InverseSphereRadius;
/// <summary>
/// Total Number of cells (width*height)
// Total number of cells in the grid (width * height).
NumCells = GridNumSideX * GridNumSideZ;

// Each axis contributes (cell scale * cell count) / sqrt(2); the larger of the two becomes the
// radius used to normalize distances.
float sphereRadiusX = (CellScaleX * GridNumSideX) / Mathf.Sqrt(2);
float sphereRadiusZ = (CellScaleZ * GridNumSideZ) / Mathf.Sqrt(2);
SphereRadius = Mathf.Max(sphereRadiusX, sphereRadiusZ);
// Same value as 1 / max(radiusX, radiusZ); reuse the radius instead of recomputing the max.
InverseSphereRadius = 1.0f / SphereRadius;

ChannelOffsets = new int[ChannelDepth.Length];
DiffNumSideZX = (GridNumSideZ - GridNumSideX);
OffsetGridNumSide = (GridNumSideZ - 1f) / 2f;

InitDepthType();
InitCellPoints();
InitPerceptionBuffer();

// The collider buffer starts at the configured initial size, never above the configured maximum.
m_ColliderBuffer = new Collider[Math.Min(MaxColliderBufferSize, InitialColliderBufferSize)];

// Default root reference to current game object
if (rootReference == null)
{
    rootReference = gameObject;
}

// One texel per grid cell, RGB24, no mip chain.
m_perceptionTexture2D = new Texture2D(GridNumSideX, GridNumSideZ, TextureFormat.RGB24, false);
}
// The explicit ISensor.Reset implementation below has an empty body (it is a no-op).
/// <inheritdoc cref="ISensor.Reset"/>
void ISensor . Reset ( ) { }
// NOTE(review): the public Reset() declaration has no body in this span and is immediately
// followed by ClearPerceptionBuffer(); lines appear to be missing -- confirm against the full file.
public void Reset ( )
public void ClearPerceptionBuffer ( )
{
if ( m_PerceptionBuffer ! = null )
{
// NOTE(review): the statements for the non-null case are not present in this span.
else
{
// No perception buffer yet: allocate one float per observation and (re)create the collider
// buffer with min(MaxColliderBufferSize, InitialColliderBufferSize) entries.
m_PerceptionBuffer = new float [ NumberOfObservations ] ;
m_ColliderBuffer = new Collider [ Math . Min ( MaxColliderBufferSize , InitialColliderBufferSize ) ] ;
}
// NOTE(review): the statement guarded by ShowGizmos is not present in this span.
if ( ShowGizmos )
/// <summary>
/// Runs one perception pass: for every grid cell, queries the physics engine for overlapping
/// colliders and hands any hits to ParseColliders.
/// </summary>
/// <returns>A float[] containing all of the information collected from the gridsensor</returns>
// NOTE(review): this span interleaves two revisions of the same method (halfCellScale is declared
// twice, the for-header appears twice, and both an OverlapBox and a BufferResizingOverlapBoxNonAlloc
// query are present). Braces and the condition choosing between the rotated and the
// Quaternion.identity query are missing -- reconcile against the full file before editing.
public float [ ] Perceive ( )
{
Reset ( ) ;
// Without a collider buffer no query can be issued; an empty array is returned instead.
if ( m_ColliderBuffer = = null )
{
return Array . Empty < float > ( ) ;
}
ClearPerceptionBuffer ( ) ;
// TODO: make these part of the class
Collider [ ] foundColliders = null ;
Vector3 cellCenter = Vector3 . zero ;
// Half extents of the query box: half the cell scale on X and Z, the full CellScaleY on Y.
var halfCellScale = new Vector3 ( CellScaleX / 2f , CellScaleY , CellScaleZ / 2f ) ;
Vector3 halfCellScale = new Vector3 ( CellScaleX / 2f , CellScaleY , CellScaleZ / 2f ) ;
for ( int cellIndex = 0 ; cellIndex < NumCells ; cellIndex + + )
for ( var cellIndex = 0 ; cellIndex < NumCells ; cellIndex + + )
int numFound ;
Vector3 cellCenter ;
// Variant 1: the cell point is transformed by this object's transform and the box uses its rotation.
cellCenter = transform . TransformPoint ( CellPoints [ cellIndex ] ) ;
foundColliders = Physics . OverlapBox ( cellCenter , halfCellScale , transform . rotation , ObserveMask ) ;
// The transform is stored in a local so it is read once for both the point and the rotation.
Transform transform1 ;
cellCenter = ( transform1 = transform ) . TransformPoint ( CellPoints [ cellIndex ] ) ;
numFound = BufferResizingOverlapBoxNonAlloc ( cellCenter , halfCellScale , transform1 . rotation ) ;
// Variant 2: the box is queried with Quaternion.identity.
// NOTE(review): the assignment of cellCenter for this variant is not present in this span.
foundColliders = Physics . OverlapBox ( cellCenter , halfCellScale , Quaternion . identity , ObserveMask ) ;
numFound = BufferResizingOverlapBoxNonAlloc ( cellCenter , halfCellScale , Quaternion . identity ) ;
// Only cells with at least one hit are parsed. The numFound form passes the shared
// m_ColliderBuffer together with the number of valid entries in it.
if ( foundColliders ! = null & & foundColliders . Length > 0 )
if ( numFound > 0 )
ParseColliders ( foundColliders , cellIndex , cellCenter ) ;
ParseColliders ( m_ColliderBuffer , numFound , cellIndex , cellCenter ) ;
/// <summary>
/// This method attempts to perform the Physics.OverlapBoxNonAlloc and will double the size of the Collider buffer
/// if the number of Colliders in the buffer after the call is equal to the length of the buffer.
/// The query is repeated after each growth until the results fit or the buffer has reached
/// <see cref="MaxColliderBufferSize"/>.
/// </summary>
/// <param name="cellCenter">Center of the query box.</param>
/// <param name="halfCellScale">Half extents of the query box.</param>
/// <param name="rotation">Rotation of the query box.</param>
/// <returns>The number of colliders written into the collider buffer by the last query.</returns>
int BufferResizingOverlapBoxNonAlloc(Vector3 cellCenter, Vector3 halfCellScale, Quaternion rotation)
{
    int numFound;
    // Since we can only get a fixed number of results, requery
    // until we're sure we can hold them all (or until we hit the max size).
    while (true)
    {
        numFound = Physics.OverlapBoxNonAlloc(cellCenter, halfCellScale, m_ColliderBuffer, rotation, ObserveMask);
        if (numFound == m_ColliderBuffer.Length && m_ColliderBuffer.Length < MaxColliderBufferSize)
        {
            // Doubling a zero-length buffer would leave it at zero and this loop would never
            // terminate, so always grow to at least one element.
            var grownLength = Math.Max(1, m_ColliderBuffer.Length * 2);
            m_ColliderBuffer = new Collider[Math.Min(MaxColliderBufferSize, grownLength)];
            // Remember the grown size so later re-allocations of the buffer start from it.
            InitialColliderBufferSize = m_ColliderBuffer.Length;
        }
        else
        {
            break;
        }
    }
    return numFound;
}
/// <summary>
/// Picks, among the colliders found for a cell, the closest one whose tag is listed in
/// DetectableObjects and forwards it to LoadObjectData with a normalized distance.
/// </summary>
/// <param name="numFound">Number of colliders found.</param>
// NOTE(review): two revisions are interleaved here. The 3-argument signature iterates the whole
// array and compares plain distances; the 4-argument signature iterates only the first numFound
// entries and compares squared distances. Braces, the statement executed when the collider belongs
// to rootReference, and the declaration/assignment of closestColliderGo are not present in this span.
protected virtual void ParseColliders ( Collider [ ] foundColliders , int cellIndex , Vector3 cellCenter )
protected virtual void ParseColliders ( Collider [ ] foundColliders , int numFound , int cellIndex , Vector3 cellCenter )
GameObject currentColliderGo = null ;
Profiler . BeginSample ( "GridSensor.ParseColliders" ) ;
Vector3 closestColliderPoint = Vector3 . zero ;
// Best distance so far starts at float.MaxValue so the first detectable collider always wins.
float distance = float . MaxValue ;
float currentDistance = 0f ;
var minDistanceSquared = float . MaxValue ;
for ( int i = 0 ; i < foundColliders . Length ; i + + )
for ( var i = 0 ; i < numFound ; i + + )
currentColliderGo = foundColliders [ i ] . gameObject ;
var currentColliderGo = foundColliders [ i ] . gameObject ;
// The collider's GameObject is compared against the agent's root reference.
if ( currentColliderGo = = rootReference )
if ( ReferenceEquals ( currentColliderGo , rootReference ) )
// Distance is measured from the root reference's position to the point on the collider's
// bounds closest to the cell center.
closestColliderPoint = foundColliders [ i ] . ClosestPointOnBounds ( cellCenter ) ;
currentDistance = Vector3 . Distance ( closestColliderPoint , rootReference . transform . position ) ;
var closestColliderPoint = foundColliders [ i ] . ClosestPointOnBounds ( cellCenter ) ;
var currentDistanceSquared = ( closestColliderPoint - rootReference . transform . position ) . sqrMagnitude ;
// Tag lookup, first form: Array.IndexOf over DetectableObjects using the tag string.
if ( ( Array . IndexOf ( DetectableObjects , currentColliderGo . tag ) > - 1 ) & & ( currentDistance < distance ) )
// Tag lookup, second form: linear scan using CompareTag, stopping at the first match.
var index = - 1 ;
for ( var ii = 0 ; ii < DetectableObjects . Length ; ii + + )
{
if ( currentColliderGo . CompareTag ( DetectableObjects [ ii ] ) )
{
index = ii ;
break ;
}
}
if ( index > - 1 & & currentDistanceSquared < minDistanceSquared )
distance = currentDistance ;
minDistanceSquared = currentDistanceSquared ;
// The winning object is loaded with its distance normalized by the grid radius: either
// distance / SphereRadius, or sqrt(minDistanceSquared) * InverseSphereRadius.
if ( closestColliderGo ! = null )
LoadObjectData ( closestColliderGo , cellIndex , distance / SphereRadius ) ;
if ( ! ReferenceEquals ( closestColliderGo , null ) )
LoadObjectData ( closestColliderGo , cellIndex , ( float ) Math . Sqrt ( minDistanceSquared ) * InverseSphereRadius ) ;
Profiler . EndSample ( ) ;
}
/// <summary>
/// Builds the per-channel values for a detected object. This base implementation writes
/// <paramref name="typeIndex"/> into channel 0 and leaves every other channel at 0.
/// </summary>
/// <remarks>
/// The returned array is a buffer owned by the sensor that is cleared and reused on every call.
/// Callers must consume or copy the values before the next call and must not keep a reference
/// to the array.
/// </remarks>
/// <param name="currentColliderGo">The detected game object (not read by this implementation).</param>
/// <param name="typeIndex">Value stored in channel 0.</param>
/// <param name="normalizedDistance">Normalized distance to the object (not read by this implementation).</param>
/// <returns>An array with one entry per element of ChannelDepth.</returns>
protected virtual float[] GetObjectData(GameObject currentColliderGo, float typeIndex, float normalizedDistance)
{
    // Reuse the scratch buffer instead of allocating a new array per detected object; it is
    // re-created only when it does not exist yet or the number of channels has changed.
    if (m_ChannelBuffer == null || m_ChannelBuffer.Length != ChannelDepth.Length)
    {
        m_ChannelBuffer = new float[ChannelDepth.Length];
    }

    // Wipe values left over from the previous call before writing the new ones.
    Array.Clear(m_ChannelBuffer, 0, m_ChannelBuffer.Length);
    m_ChannelBuffer[0] = typeIndex;
    return m_ChannelBuffer;
}
/// <summary>
/// Writes the observation for one cell: finds the DetectableObjects entry matching the object's
/// tag, obtains its channel values from GetObjectData and stores them in the perception buffer
/// slot of the given cell.
/// </summary>
/// <param name="currentColliderGo">The game object that was found colliding with a certain cell</param>
/// <param name="cellIndex">The index of the current cell</param>
/// <param name="normalized_distance">A float between 0 and 1 describing the ratio of
/// the object's distance to the grid radius (older signature)</param>
/// <param name="normalizedDistance">A float between 0 and 1 describing the ratio of
/// the object's distance to the grid radius</param>
// NOTE(review): two revisions are interleaved in this span (both signatures, both loop headers,
// an Array.Copy-based write and a direct write into m_PerceptionBuffer). The braces and the
// construct that selects between the "channel based" and "channel hot" layouts (the break
// statements suggest a switch) are missing -- reconcile against the full file before editing.
protected virtual void LoadObjectData ( GameObject currentColliderGo , int cellIndex , float normalized_distance )
protected virtual void LoadObjectData ( GameObject currentColliderGo , int cellIndex , float normalizedDistance )
for ( int i = 0 ; i < DetectableObjects . Length ; i + + )
Profiler . BeginSample ( "GridSensor.LoadObjectData" ) ;
// View of this cell's slot in the perception buffer: starts at cellIndex * ObservationPerCell
// and spans ObservationPerCell floats.
var channelHotVals = new ArraySegment < float > ( m_PerceptionBuffer , cellIndex * ObservationPerCell , ObservationPerCell ) ;
for ( var i = 0 ; i < DetectableObjects . Length ; i + + )
if ( currentColliderGo ! = null & & currentColliderGo . CompareTag ( DetectableObjects [ i ] ) )
// Zero the cell's slot before writing new values into it.
for ( var ii = 0 ; ii < channelHotVals . Count ; ii + + )
{
m_PerceptionBuffer [ channelHotVals . Offset + ii ] = 0f ;
}
if ( ! ReferenceEquals ( currentColliderGo , null ) & & currentColliderGo . CompareTag ( DetectableObjects [ i ] ) )
// The type index passed on is the 1-based position of the matching tag in DetectableObjects.
float [ ] channelValues = GetObjectData ( currentColliderGo , ( float ) i + 1 , normalized_distance ) ;
float [ ] channelValues = GetObjectData ( currentColliderGo , ( float ) i + 1 , normalizedDistance ) ;
if ( ShowGizmos )
{
Color debugRayColor = Color . white ;
}
// The cell's gizmo color is the debug color with its alpha set to .5.
// NOTE(review): debugRayColor is declared inside the ShowGizmos block above but read here;
// lines assigning it appear to be missing from this span.
CellActivity [ cellIndex ] = new Color ( debugRayColor . r , debugRayColor . g , debugRayColor . b , . 5f ) ;
}
/// <remarks>
/// The observations are "channel based" so each grid is WxHxC where C is the number of channels
/// This typically means that each channel value is normalized between 0 and 1
/// If channelDepth is 1, the value is assumed normalized, else the value is normalized by the channelDepth
/// The channels are then stored consecutively in PerceptionBuffer.
/// NOTE: This is the only grid type that uses floating point values
/// For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams:
/// channelValues = {2, 1}
/// ObservationPerCell = channelValues.Length
/// channelValues = {2f/5f, 1f/3f} = {.4, .33..}
/// Array.Copy(channelValues, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
/// </remarks>
for ( int j = 0 ; j < channelValues . Length ; j + + )
channelValues [ j ] / = ChannelDepth [ j ] ;
// The observations are "channel based" so each grid is WxHxC where C is the number of channels
// This typically means that each channel value is normalized between 0 and 1
// If channelDepth is 1, the value is assumed normalized, else the value is normalized by the channelDepth
// The channels are then stored consecutively in PerceptionBuffer.
// NOTE: This is the only grid type that uses floating point values
// For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams:
// channelValues = {2, 1}
// ObservationPerCell = channelValues.Length
// channelValues = {2f/5f, 1f/3f} = {.4, .33..}
// Array.Copy(channelValues, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
for ( int j = 0 ; j < channelValues . Length ; j + + )
{
// In-place division: the array returned by GetObjectData is modified here.
channelValues [ j ] / = ChannelDepth [ j ] ;
}
Array . Copy ( channelValues , 0 , m_PerceptionBuffer , cellIndex * ObservationPerCell , ObservationPerCell ) ;
break ;
Array . Copy ( channelValues , 0 , m_PerceptionBuffer , cellIndex * ObservationPerCell , ObservationPerCell ) ;
break ;
/// <remarks>
/// The observations are "channel hot" so each grid is WxHxD where D is the sum of all of the channel depths
/// The opposite of the "channel based" case, the channel values are represented as one hot vector per channel and then concatenated together
/// Thus channelDepth is assumed to be greater than 1.
/// For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams,
/// channelValues = {2, 1}
/// channelOffsets = {5, 3}
/// ObservationPerCell = 5 + 3 = 8
/// channelHotVals = {0, 0, 1, 0, 0, 0, 1, 0}
/// Array.Copy(channelHotVals, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
/// </remarks>
// Older form: the one-hot vector is built in a temporary array and copied afterwards.
float [ ] channelHotVals = new float [ ObservationPerCell ] ;
for ( int j = 0 ; j < channelValues . Length ; j + + )
if ( ChannelDepth [ j ] > 1 )
{
channelHotVals [ ( int ) channelValues [ j ] + ChannelOffsets [ j ] ] = 1f ;
}
else
// The observations are "channel hot" so each grid is WxHxD where D is the sum of all of the channel depths
// The opposite of the "channel based" case, the channel values are represented as one hot vector per channel and then concatenated together
// Thus channelDepth is assumed to be greater than 1.
// For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams,
// channelValues = {2, 1}
// channelOffsets = {5, 3}
// ObservationPerCell = 5 + 3 = 8
// channelHotVals = {0, 0, 1, 0, 0, 0, 1, 0}
// Array.Copy(channelHotVals, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
for ( int j = 0 ; j < channelValues . Length ; j + + )
channelHotVals [ ChannelOffsets [ j ] ] = channelValues [ j ] ;
// Newer form: values are written straight into the cell's slot of m_PerceptionBuffer. A channel
// with depth > 1 sets a single 1 at (slot offset + value + channel offset); a channel with
// depth <= 1 stores the raw value at (slot offset + channel offset).
if ( ChannelDepth [ j ] > 1 )
{
m_PerceptionBuffer [ channelHotVals . Offset + ( int ) channelValues [ j ] + ChannelOffsets [ j ] ] = 1f ;
}
else
{
m_PerceptionBuffer [ channelHotVals . Offset + ChannelOffsets [ j ] ] = channelValues [ j ] ;
}
break ;
}
Array . Copy ( channelHotVals , 0 , m_PerceptionBuffer , cellIndex * ObservationPerCell , ObservationPerCell ) ;
break ;
}
Profiler . EndSample ( ) ;
}
/// <summary>Converts the index of the cell to the 3D point (y is zero)</summary>
CellActivity [ toCellID ] = CellActivity [ fromCellID ] ;
}
/// <summary>Creates an independent copy of a float array.</summary>
/// <param name="array">The array to copy from.</param>
/// <returns>A newly allocated float[] holding the same values as <paramref name="array"/>.</returns>
private static float[] CreateCopy(float[] array)
{
    // Clone() on a float[] yields a new array of the same length with the same element values.
    var duplicate = (float[])array.Clone();
    return duplicate;
}
/// <summary>Looks up the position of a tag within DetectableObjects.</summary>
/// <param name="tag">The tag to search for</param>
/// <returns>
/// The index of the first entry of DetectableObjects equal to <paramref name="tag"/>, as reported
/// by Array.IndexOf (a negative value when the tag is not present).
/// </returns>
public int IndexOfTag(string tag) => Array.IndexOf(DetectableObjects, tag);
// Draws the grid as gizmos when ShowGizmos is set; Perceive() is run first.
// NOTE(review): two revisions are interleaved here (scale/offset/oldGizmoMatrix and the loop header
// each appear twice), the braces around the ShowGizmos body are missing, and the method is cut off
// after the RotateToAgent branch -- reconcile against the full file before editing.
void OnDrawGizmos ( )
{
if ( ShowGizmos )
Perceive ( ) ;
// Each cell cube is CellScaleX by CellScaleZ on the ground plane and 1 unit tall, raised by GizmoYOffset.
Vector3 scale = new Vector3 ( CellScaleX , 1 , CellScaleZ ) ;
Vector3 offset = new Vector3 ( 0 , GizmoYOffset , 0 ) ;
// The current gizmo matrix is captured in a local before the per-cell loop.
Matrix4x4 oldGizmoMatrix = Gizmos . matrix ;
Matrix4x4 cubeTransform = Gizmos . matrix ;
for ( int i = 0 ; i < NumCells ; i + + )
var scale = new Vector3 ( CellScaleX , 1 , CellScaleZ ) ;
var offset = new Vector3 ( 0 , GizmoYOffset , 0 ) ;
var oldGizmoMatrix = Gizmos . matrix ;
for ( var i = 0 ; i < NumCells ; i + + )
Matrix4x4 cubeTransform ;
// With RotateToAgent set, the cube is placed at the cell's point plus the offset and uses this
// object's rotation.
if ( RotateToAgent )
{
cubeTransform = Matrix4x4 . TRS ( CellToPoint ( i ) + offset , transform . rotation , scale ) ;
}
/// <inheritdoc/>
/// <remarks>
/// Runs a perception pass inside a "GridSensor.Update" timer scope. This is the only definition of
/// the explicit interface member; the additional empty definition that preceded it declared the
/// same member twice and has been removed.
/// </remarks>
void ISensor.Update()
{
    using (TimerStack.Instance.Scoped("GridSensor.Update"))
    {
        Perceive();
    }
}
/// <summary>Gets the observation shape</summary>
/// <returns>int[] of the observation shape</returns>
{
using ( TimerStack . Instance . Scoped ( "GridSensor.WriteToTensor" ) )
{
Perceive ( ) ;
int index = 0 ;
for ( var h = GridNumSideZ - 1 ; h > = 0 ; h - - ) // height
{