using System;
using System.Collections.Generic;
using Unity.Collections;
using UnityEngine;
using UnityEngine.Assertions;
using Unity.MLAgents.Sensors;
using UnityEngine.Profiling;
using Unity.Jobs;
using UnityEngine.Jobs;
using Unity.Burst;
using Unity.Collections.LowLevel.Unsafe;
#if UNITY_EDITOR
using UnityEditor;
#endif

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Grid-based sensor. Each cell of a <see cref="GridNumSideX"/> x <see cref="GridNumSideZ"/> grid placed
    /// around this component's transform is probed with a downward <see cref="BoxcastCommand"/> (scheduled as
    /// jobs in <see cref="Perceive"/>), and the first hit whose tag is in <see cref="DetectableObjects"/> is
    /// encoded into that cell's slice of the perception buffer.
    /// </summary>
    [AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)]
    public class GridSensor : SensorComponent, ISensor, IBuiltInSensor
    {
        /// <summary>
        /// Name of this grid sensor.
        /// </summary>
        public string Name;

        //
        // Main Parameters
        //

        /// <summary>
        /// The width of each grid cell.
        /// </summary>
        [Header("Grid Sensor Settings")]
        [Tooltip("The width of each grid cell")]
        [Range(0.05f, 1000f)]
        public float CellScaleX = 1f;

        /// <summary>
        /// The depth of each grid cell.
        /// </summary>
        [Tooltip("The depth of each grid cell")]
        [Range(0.05f, 1000f)]
        public float CellScaleZ = 1f;

        /// <summary>
        /// The width of the grid, in cells.
        /// </summary>
        [Tooltip("The width of the grid")]
        [Range(2, 2000)]
        public int GridNumSideX = 16;

        /// <summary>
        /// The depth of the grid, in cells.
        /// </summary>
        [Tooltip("The depth of the grid")]
        [Range(2, 2000)]
        public int GridNumSideZ = 16;

        /// <summary>
        /// The height of each grid cell. Changes how much of the vertical axis is observed by a cell.
        /// </summary>
        [Tooltip("The height of each grid cell. Changes how much of the vertical axis is observed by a cell")]
        [Range(0.01f, 1000f)]
        public float CellScaleY = 0.01f;

        /// <summary>
        /// Rotate the grid based on the direction the agent is facing.
        /// </summary>
        [Tooltip("Rotate the grid based on the direction the agent is facing")]
        public bool RotateToAgent;

        /// <summary>
        /// Array holding the depth of each channel.
        /// </summary>
        [Tooltip("Array holding the depth of each channel")]
        public int[] ChannelDepth;

        /// <summary>
        /// List of tags that are detected.
        /// </summary>
        [Tooltip("List of tags that are detected")]
        public string[] DetectableObjects;

        /// <summary>
        /// The layer mask passed to every boxcast.
        /// </summary>
        [Tooltip("The layer mask")]
        public LayerMask ObserveMask;

        /// <summary>
        /// Enum describing what kind of depth type the data should be organized as.
        /// </summary>
        public enum GridDepthType
        {
            /// <summary>One float per channel, normalized by that channel's depth.</summary>
            Channel,

            /// <summary>One one-hot block per channel (for depths greater than 1), concatenated per cell.</summary>
            ChannelHot
        }

        /// <summary>
        /// The data layout that the grid should output.
        /// </summary>
        [Tooltip("The data layout that the grid should output")]
        public GridDepthType gridDepthType = GridDepthType.Channel;

        /// <summary>
        /// The reference of the root of the agent. This is used to disambiguate objects with the same tag as the
        /// agent. Defaults to the current GameObject.
        /// NOTE(review): the job-based detection in UpdateBufferFromJob does not currently consult this field.
        /// </summary>
        [Tooltip("The reference of the root of the agent. This is used to disambiguate objects with the same tag as the agent. Defaults to current GameObject")]
        public GameObject rootReference;

        /// <summary>
        /// Upper bound for the collider buffer of the overlap-box implementation.
        /// NOTE(review): not read by the boxcast-based implementation in this class.
        /// </summary>
        [Header("Collider Buffer Properties")]
        [Tooltip("The absolute max size of the Collider buffer used in the non-allocating Physics calls. In other words" +
            " the Collider buffer will never grow beyond this number even if there are more Colliders in the Grid Cell.")]
        public int MaxColliderBufferSize = 500;

        /// <summary>
        /// Initial per-cell collider buffer size of the overlap-box implementation.
        /// NOTE(review): not read by the boxcast-based implementation in this class.
        /// </summary>
        [Tooltip(
            "The Estimated Max Number of Colliders to expect per cell. This number is used to " +
            "pre-allocate an array of Colliders in order to take advantage of the OverlapBoxNonAlloc " +
            "Physics API. If the number of colliders found is >= InitialColliderBufferSize the array " +
            "will be resized to double its current size. The hard coded absolute size is 500.")]
        public int InitialColliderBufferSize = 4;

        /// <summary>One boxcast result per cell, filled by the batch scheduled in <see cref="Perceive"/>.</summary>
        NativeArray<RaycastHit> m_Hits;

        /// <summary>Allocated per cell in <see cref="Start"/>; not read elsewhere in this class.</summary>
        NativeArray<int> m_HitIndexes;

        /// <summary>One boxcast command per cell, built by the CreateBoxcast* jobs.</summary>
        NativeArray<BoxcastCommand> m_BoxcastCommands;

        /// <summary>
        /// Maps each tag in <see cref="DetectableObjects"/> to its index. Rebuilt in <see cref="Start"/>.
        /// </summary>
        protected Dictionary<string, int> m_DetectableObjectToIndex = new Dictionary<string, int>();

        /// <summary>
        /// The offsets used to specify where within a cell's allotted data, certain observations will be inserted.
        /// Only filled in for <see cref="GridDepthType.ChannelHot"/>.
        /// </summary>
        [HideInInspector]
        public NativeArray<int> ChannelOffsets;

        /// <summary>
        /// Read-only access to the main perception storage.
        /// </summary>
        public NativeArray<float> PerceptionBuffer => m_PerceptionBuffer;

        /// <summary>
        /// The main storage of perceptual information: <see cref="ObservationPerCell"/> floats per cell, cell-major.
        /// </summary>
        protected NativeArray<float> m_PerceptionBuffer;

        /// <summary>
        /// The default value of one cell of the perception buffer when using the ChannelHot DepthType.
        /// Used to reset the array.
        /// </summary>
        protected NativeArray<float> m_ChannelHotDefaultPerceptionBuffer;

        /// <summary>
        /// Array of colors displaying the DebugColors for each cell in OnDrawGizmos. Only updated if ShowGizmos.
        /// </summary>
        protected NativeArray<Color> CellActivity;

        /// <summary>
        /// Array of positions where each position is the center of a cell (untransformed, y = 0).
        /// </summary>
        private NativeArray<Vector3> CellPoints;

        /// <summary>Scratch buffer returned by <see cref="GetObjectData"/>; one entry per channel.</summary>
        float[] m_ChannelBuffer;

        //
        // Hidden Parameters
        //

        /// <summary>
        /// The total number of observations per cell of the grid. It is equivalent to the "channel" on the outgoing tensor.
        /// </summary>
        [HideInInspector]
        public int ObservationPerCell;

        /// <summary>
        /// The total number of observations that this GridSensor provides. It is the length of m_PerceptionBuffer.
        /// </summary>
        [HideInInspector]
        public int NumberOfObservations;

        /// <summary>
        /// Array of Colors needed in order to load the values of the perception buffer to a texture.
        /// </summary>
        protected Color[] m_PerceptionColors;

        /// <summary>
        /// Texture where the colors are written to so that they can be compressed in PNG format.
        /// </summary>
        protected Texture2D m_perceptionTexture2D;

        //
        // Utility Constants Calculated on Init
        //

        /// <summary>
        /// Number of PNG formatted images that are sent to python during training.
        /// </summary>
        private int NumImages;

        /// <summary>
        /// Number of relevant channels on the last image that is sent.
        /// </summary>
        private int NumChannelsOnLastImage;

        /// <summary>
        /// Radius of grid, used for normalizing the distance.
        /// </summary>
        protected float InverseSphereRadius;

        /// <summary>
        /// Total Number of cells (width*height).
        /// </summary>
        private int NumCells;

        /// <summary>
        /// Difference between GridNumSideZ and GridNumSideX.
        /// </summary>
        protected int DiffNumSideZX = 0;

        /// <summary>
        /// Offset used for calculating CellToPoint: (GridNumSideZ - 1) / 2.
        /// </summary>
        protected float OffsetGridNumSide = 7.5f;

        /// <summary>
        /// Half of the grid in the X direction.
        /// </summary>
        private float HalfOfGridX;

        /// <summary>
        /// Half of the grid in the Z direction.
        /// </summary>
        private float HalfOfGridZ;

        /// <summary>
        /// Used in the PointToCell method to scale the x value to land in the calculated cell.
        /// </summary>
        private float PointToCellScalingX;

        /// <summary>
        /// Used in the PointToCell method to scale the z value to land in the calculated cell.
        /// </summary>
        private float PointToCellScalingZ;

        /// <summary>
        /// Whether the perception buffer has been initialized while playing.
        /// </summary>
        protected bool Initialized = false;

        /// <summary>
        /// Array holding the dimensions of the resulting tensor.
        /// </summary>
        private int[] m_Shape;

        //
        // Debug Parameters
        //

        /// <summary>
        /// Array of Colors used for the grid gizmos, indexed like <see cref="DetectableObjects"/>.
        /// </summary>
        [Header("Debug Options")]
        [Tooltip("Array of Colors used for the grid gizmos")]
        public Color[] DebugColors;

        /// <summary>
        /// The height of the gizmos grid.
        /// </summary>
        [Tooltip("The height of the gizmos grid")]
        public float GizmoYOffset = 0f;

        /// <summary>
        /// Whether to show gizmos or not.
        /// </summary>
        [Tooltip("Whether to show gizmos or not")]
        public bool ShowGizmos = false;

        /// <summary>
        /// Compression reported by <see cref="GetCompressionType"/>.
        /// </summary>
        public SensorCompressionType CompressionType = SensorCompressionType.PNG;

        /// <summary>
        /// List representing the multiple compressed images of all of the grids.
        /// </summary>
        private List<byte[]> compressedImgs;

        /// <summary>
        /// List representing the sizes of the multiple images so they can be properly reconstructed on the python side.
        /// </summary>
        private List<byte[]> byteSizesBytesList;

        private Color DebugDefaultColor = new Color(1f, 1f, 1f, 0.25f);

        /// <summary>
        /// Handle of the last job chain scheduled by <see cref="Perceive"/> (clear jobs, boxcast-command
        /// creation and the boxcast batch). Must be completed before touching any of the arrays those jobs use.
        /// </summary>
        internal JobHandle m_boxcastJobHandle;

        /// <inheritdoc/>
        public override ISensor CreateSensor()
        {
            return this;
        }

        /// <summary>
        /// Sets the parameters of the grid sensor.
        /// </summary>
        /// <param name="detectableObjects">array of strings representing the tags to be detected by the sensor</param>
        /// <param name="channelDepth">array of ints representing the depth of each channel</param>
        /// <param name="gridDepthType">enum representing the GridDepthType of the sensor</param>
        /// <param name="cellScaleX">float representing the X scaling of each cell</param>
        /// <param name="cellScaleZ">float representing the Z scaling of each cell</param>
        /// <param name="gridWidth">int representing the number of cells in the X direction. Width of the Grid</param>
        /// <param name="gridHeight">int representing the number of cells in the Z direction. Height of the Grid</param>
        /// <param name="observeMaskInt">int representing the layer mask to observe</param>
        /// <param name="rotateToAgent">bool; if true the grid is rotated with this component's transform</param>
        /// <param name="debugColors">array of colors corresponding to the tags in the detectableObjects array</param>
        public virtual void SetParameters(string[] detectableObjects, int[] channelDepth, GridDepthType gridDepthType,
            float cellScaleX, float cellScaleZ, int gridWidth, int gridHeight, int observeMaskInt, bool rotateToAgent,
            Color[] debugColors)
        {
            this.ObserveMask = observeMaskInt;
            this.DetectableObjects = detectableObjects;
            this.ChannelDepth = channelDepth;
            this.gridDepthType = gridDepthType;
            this.CellScaleX = cellScaleX;
            this.CellScaleZ = cellScaleZ;
            this.GridNumSideX = gridWidth;
            this.GridNumSideZ = gridHeight;
            this.RotateToAgent = rotateToAgent;
            this.DiffNumSideZX = (GridNumSideZ - GridNumSideX);
            this.OffsetGridNumSide = (GridNumSideZ - 1f) / 2f;
            this.DebugColors = debugColors;
        }

        /// <summary>
        /// Initializes the constant parameters used within the perceive method call and (re)allocates
        /// <see cref="ChannelOffsets"/> with one entry per channel.
        /// </summary>
        public void InitGridParameters()
        {
            NumCells = GridNumSideX * GridNumSideZ;
            var sphereRadiusX = (CellScaleX * GridNumSideX) / Mathf.Sqrt(2);
            var sphereRadiusZ = (CellScaleZ * GridNumSideZ) / Mathf.Sqrt(2);
            InverseSphereRadius = 1.0f / Mathf.Max(sphereRadiusX, sphereRadiusZ);

            DisposeNativeArray(ChannelOffsets);
            ChannelOffsets = new NativeArray<int>(ChannelDepth.Length, Allocator.Persistent);

            DiffNumSideZX = (GridNumSideZ - GridNumSideX);
            OffsetGridNumSide = (GridNumSideZ - 1f) / 2f;
            HalfOfGridX = CellScaleX * GridNumSideX / 2;
            HalfOfGridZ = CellScaleZ * GridNumSideZ / 2;

            PointToCellScalingX = GridNumSideX / (CellScaleX * GridNumSideX);
            PointToCellScalingZ = GridNumSideZ / (CellScaleZ * GridNumSideZ);
        }

        /// <summary>
        /// Initializes the constant parameters that are based on the Grid Depth Type.
        /// Sets the ObservationPerCell and (for ChannelHot) the ChannelOffsets properties.
        /// Expects <see cref="ChannelOffsets"/> to have one entry per channel (see <see cref="InitGridParameters"/>).
        /// </summary>
        public virtual void InitDepthType()
        {
            switch (gridDepthType)
            {
                case GridDepthType.Channel:
                    ObservationPerCell = ChannelDepth.Length;
                    break;

                case GridDepthType.ChannelHot:
                    // Each channel's one-hot block starts where the previous ones end; the first starts at 0.
                    ObservationPerCell = 0;
                    for (var i = 0; i < ChannelDepth.Length; i++)
                    {
                        ChannelOffsets[i] = ObservationPerCell;
                        ObservationPerCell += ChannelDepth[i];
                    }
                    break;
            }

            // The maximum number of channels in the final output must be less than 255 * 3 because the "number of PNG images" to generate must fit in one byte
            Assert.IsTrue(ObservationPerCell < (255 * 3), "The maximum number of channels per cell must be less than 255 * 3");
        }

        /// <summary>
        /// Initializes the location of the CellPoints property (untransformed cell centers).
        /// </summary>
        private void InitCellPoints()
        {
            DisposeNativeArray(CellPoints);
            CellPoints = new NativeArray<Vector3>(NumCells, Allocator.Persistent);

            for (var i = 0; i < NumCells; i++)
            {
                CellPoints[i] = CellToPoint(i, false);
            }
        }

        /// <summary>
        /// Initializes the m_ChannelHotDefaultPerceptionBuffer with default data in the case that the grid depth
        /// type is ChannelHot: for every channel deeper than 1, the first slot of its one-hot block is set to 1.
        /// </summary>
        public virtual void InitChannelHotDefaultPerceptionBuffer()
        {
            // The ChannelHot clear job reads this array; it must not be freed while that job may still run.
            m_boxcastJobHandle.Complete();

            DisposeNativeArray(m_ChannelHotDefaultPerceptionBuffer);
            m_ChannelHotDefaultPerceptionBuffer = new NativeArray<float>(ObservationPerCell, Allocator.Persistent);
            for (var i = 0; i < ChannelDepth.Length; i++)
            {
                if (ChannelDepth[i] > 1)
                {
                    m_ChannelHotDefaultPerceptionBuffer[ChannelOffsets[i]] = 1;
                }
            }
        }

        /// <summary>
        /// Initializes the m_PerceptionBuffer as the main data storage property.
        /// Calculates the NumImages and NumChannelsOnLastImage that are used for serializing m_PerceptionBuffer.
        /// </summary>
        public void InitPerceptionBuffer()
        {
            if (Application.isPlaying)
                Initialized = true;

            NumberOfObservations = ObservationPerCell * NumCells;

            // Scheduled jobs write m_PerceptionBuffer and CellActivity; finish them before reallocating.
            m_boxcastJobHandle.Complete();
            DisposeNativeArray(m_PerceptionBuffer);
            m_PerceptionBuffer = new NativeArray<float>(NumberOfObservations, Allocator.Persistent);

            if (gridDepthType == GridDepthType.ChannelHot)
            {
                InitChannelHotDefaultPerceptionBuffer();
            }

            m_PerceptionColors = new Color[NumCells];

            // Channels are packed three per RGB image; the last image may carry fewer.
            NumImages = ObservationPerCell / 3;
            NumChannelsOnLastImage = ObservationPerCell % 3;
            if (NumChannelsOnLastImage == 0)
                NumChannelsOnLastImage = 3;
            else
                NumImages += 1;

            m_ChannelBuffer = new float[ChannelDepth.Length];

            DisposeNativeArray(CellActivity);
            CellActivity = new NativeArray<Color>(NumCells, Allocator.Persistent);
        }

        void OnEnable()
        {
            Start();
        }

        /// <summary>
        /// Calls the initialization methods and creates the data-storing properties used to send the data.
        /// Safe to call repeatedly (it is called from OnEnable, by Unity, and from OnDrawGizmos in edit mode):
        /// previously allocated native arrays and the perception texture are released first.
        /// </summary>
        public virtual void Start()
        {
            // A Perceive() from an earlier call may still have jobs reading CellPoints and writing the
            // buffers that are reallocated below.
            m_boxcastJobHandle.Complete();

            InitGridParameters();
            InitDepthType();
            InitCellPoints();
            InitPerceptionBuffer();

            DisposeNativeArray(m_Hits);
            m_Hits = new NativeArray<RaycastHit>(NumCells, Allocator.Persistent);
            DisposeNativeArray(m_HitIndexes);
            m_HitIndexes = new NativeArray<int>(NumCells, Allocator.Persistent);
            DisposeNativeArray(m_BoxcastCommands);
            m_BoxcastCommands = new NativeArray<BoxcastCommand>(NumCells, Allocator.Persistent);

            // Rebuild from scratch so tags removed from DetectableObjects do not linger.
            m_DetectableObjectToIndex.Clear();
            for (var i = 0; i < DetectableObjects.Length; i++)
            {
                m_DetectableObjectToIndex[DetectableObjects[i]] = i;
            }

            // Default root reference to current game object
            if (rootReference == null)
                rootReference = gameObject;
            m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };

            compressedImgs = new List<byte[]>();
            byteSizesBytesList = new List<byte[]>();

            DestroyPerceptionTexture();
            m_perceptionTexture2D = new Texture2D(GridNumSideX, GridNumSideZ, TextureFormat.RGB24, false);
        }

        /// <inheritdoc/>
        void ISensor.Reset() { }

        /// <summary>Fills every entry of the perception buffer with the ChannelHot per-cell default.</summary>
        [BurstCompile(CompileSynchronously = true)]
        struct ClearBufferOneHotJob : IJobParallelFor
        {
            public NativeArray<float> PerceptionBuf;
            [ReadOnly] public NativeArray<float> ChannelHotBuf;
            public int ObservationPerCell;

            public void Execute(int index)
            {
                PerceptionBuf[index] = ChannelHotBuf[index % ObservationPerCell];
            }
        }

        /// <summary>Zeroes every entry of the perception buffer.</summary>
        [BurstCompile(CompileSynchronously = true)]
        struct ClearBufferChannelJob : IJobParallelFor
        {
            public NativeArray<float> PerceptionBuf;

            public void Execute(int index)
            {
                PerceptionBuf[index] = 0;
            }
        }

        /// <summary>Resets every cell's debug color to <see cref="DebugDefault"/>.</summary>
        [BurstCompile(CompileSynchronously = true)]
        struct ClearCellActivityJob : IJobParallelFor
        {
            public NativeArray<Color> CellActivity;
            public Color DebugDefault;

            public void Execute(int index)
            {
                CellActivity[index] = DebugDefault;
            }
        }

        /// <summary>
        /// Clears the perception buffer before loading in new data. If the gridDepthType is ChannelHot, then it
        /// copies the per-cell default into every cell; otherwise every value is set to 0.
        /// If ShowGizmos is set, also reinitializes the cell activity array (for debug).
        /// This is the synchronous counterpart of the clear jobs scheduled in <see cref="Perceive"/>.
        /// </summary>
        public void ClearPerceptionBuffer()
        {
            Profiler.BeginSample("ClearPerceptionBuffer");
            {
                if (gridDepthType == GridDepthType.ChannelHot)
                {
                    // Copy the default value to the array
                    for (var i = 0; i < m_PerceptionBuffer.Length; i++)
                    {
                        m_PerceptionBuffer[i] = m_ChannelHotDefaultPerceptionBuffer[i % ObservationPerCell];
                    }
                }
                else
                {
                    for (var i = 0; i < m_PerceptionBuffer.Length; i++)
                    {
                        m_PerceptionBuffer[i] = 0;
                    }
                }

                if (ShowGizmos)
                {
                    // Assign the default color to the cell activities
                    for (var i = 0; i < NumCells; i++)
                    {
                        CellActivity[i] = DebugDefaultColor;
                    }
                }
            }
            Profiler.EndSample();
        }

        void OnDisable()
        {
            // Do not leave scheduled work running against this component's buffers while it is inactive.
            m_boxcastJobHandle.Complete();
        }

        void OnDestroy()
        {
            // Disposing a NativeArray that a scheduled job still uses is an error, so finish the jobs first.
            m_boxcastJobHandle.Complete();

            DisposeNativeArray(m_Hits);
            DisposeNativeArray(m_HitIndexes);
            DisposeNativeArray(m_BoxcastCommands);
            DisposeNativeArray(CellActivity);
            DisposeNativeArray(CellPoints);
            DisposeNativeArray(m_ChannelHotDefaultPerceptionBuffer);
            DisposeNativeArray(ChannelOffsets);
            DisposeNativeArray(m_PerceptionBuffer);

            DestroyPerceptionTexture();
        }

        /// <summary>Disposes <paramref name="array"/> if it has been allocated.</summary>
        static void DisposeNativeArray<TD>(NativeArray<TD> array) where TD : struct
        {
            if (array.IsCreated)
            {
                array.Dispose();
            }
        }

        /// <summary>
        /// Destroys <see cref="m_perceptionTexture2D"/> if it is alive, using DestroyImmediate outside play mode.
        /// </summary>
        void DestroyPerceptionTexture()
        {
            if (m_perceptionTexture2D == null)
            {
                return;
            }

            if (Application.isPlaying)
                Destroy(m_perceptionTexture2D);
            else
                DestroyImmediate(m_perceptionTexture2D);
            m_perceptionTexture2D = null;
        }

        /// <summary>
        /// Work-stealing batch size for the jobs scheduled in <see cref="Perceive"/>: about a twelfth of the
        /// work, but never 0 (which <c>length / 12</c> alone yields for fewer than 12 items, e.g. a 2x2 grid).
        /// </summary>
        static int BatchSize(int length)
        {
            return Math.Max(1, length / 12);
        }

        /// <summary>Gets the shape of the grid observation.</summary>
        /// <returns>integer array shape of the grid observation</returns>
        public int[] GetFloatObservationShape()
        {
            m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
            return m_Shape;
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return Name;
        }

        /// <inheritdoc/>
        public virtual SensorCompressionType GetCompressionType()
        {
            return CompressionType;
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.GridSensor;
        }

        /// <summary>
        /// GetCompressedObservation - finishes the pending perception jobs, then writes the data stored in the
        /// perception buffer onto m_perceptionTexture2D three channels at a time, PNG-encodes each image and
        /// returns the concatenated bytes.
        /// </summary>
        /// <returns>byte[] containing the compressed observation of the grid observation</returns>
        public byte[] GetCompressedObservation()
        {
            UpdateBufferFromJob();
            var allBytes = new List<byte>();
            Profiler.BeginSample("GridSensor.GetCompressedObservation");
            {
                for (var i = 0; i < NumImages - 1; i++)
                {
                    ChannelsToTexture(3 * i, 3);
                    allBytes.AddRange(m_perceptionTexture2D.EncodeToPNG());
                }

                ChannelsToTexture(3 * (NumImages - 1), NumChannelsOnLastImage);
                allBytes.AddRange(m_perceptionTexture2D.EncodeToPNG());
            }
            Profiler.EndSample();

            return allBytes.ToArray();
        }

        /// <summary>
        /// ChannelsToTexture - Takes the channel index and the numChannelsToAdd.
        /// For each cell and for each channel to add, sets it to a value of the color specified for that cell.
        /// All colors are then set to the perceptionTexture via SetPixels.
        /// m_perceptionTexture2D can then be read as an image as it now contains all of the information that was
        /// stored in the channels.
        /// </summary>
        /// <param name="channelIndex">First channel (within a cell) to copy.</param>
        /// <param name="numChannelsToAdd">Number of consecutive channels to copy into the R, G, B components.</param>
        protected void ChannelsToTexture(int channelIndex, int numChannelsToAdd)
        {
            for (var i = 0; i < NumCells; i++)
            {
                for (var j = 0; j < numChannelsToAdd; j++)
                {
                    m_PerceptionColors[i][j] = m_PerceptionBuffer[i * ObservationPerCell + channelIndex + j];
                }
            }

            m_perceptionTexture2D.SetPixels(m_PerceptionColors);
        }

        /// <summary>
        /// Builds one axis-aligned downward boxcast per cell, translated by <see cref="Pos"/>.
        /// NOTE(review): starts 2 units above the cell while the rotated variant starts 5 units above; confirm
        /// whether that difference is intended.
        /// </summary>
        [BurstCompile(CompileSynchronously = true)]
        struct CreateBoxcastBatch : IJobParallelFor
        {
            public NativeArray<BoxcastCommand> Commands;
            public Vector3 halfScale;
            public int ObserveMask;
            [ReadOnly] public NativeArray<Vector3> CellPoints;
            public Vector3 Pos;

            public void Execute(int index)
            {
                var cellCenter = Pos + CellPoints[index];
                var rotation = Quaternion.identity;
                Commands[index] = new BoxcastCommand(new Vector3(cellCenter.x, cellCenter.y + 2.0f, cellCenter.z), halfScale, rotation, Vector3.down, 10.0f, ObserveMask);
            }
        }

        /// <summary>
        /// Builds one downward boxcast per cell, with cell centers transformed by <see cref="Mat"/> and boxes
        /// oriented by <see cref="Rotation"/>.
        /// </summary>
        [BurstCompile(CompileSynchronously = true)]
        struct CreateBoxcastRotateBatch : IJobParallelFor
        {
            public NativeArray<BoxcastCommand> Commands;
            public Vector3 halfScale;
            public int ObserveMask;
            [ReadOnly] public NativeArray<Vector3> CellPoints;
            [ReadOnly] public Matrix4x4 Mat;
            public Quaternion Rotation;

            public void Execute(int index)
            {
                var cellCenter = Mat.MultiplyPoint(CellPoints[index]);
                var rotation = Rotation;
                Commands[index] = new BoxcastCommand(new Vector3(cellCenter.x, cellCenter.y + 5.0f, cellCenter.z), halfScale, rotation, Vector3.down, 10.0f, ObserveMask);
            }
        }

        // ReSharper disable Unity.PerformanceAnalysis
        /// <summary>
        /// Perceive - completes the previous job chain, then schedules (without waiting) jobs that clear the
        /// buffers, build one boxcast per cell and run the boxcast batch into m_Hits. The hits are decoded into
        /// the perception buffer later by <see cref="UpdateBufferFromJob"/>.
        /// </summary>
        /// <returns>
        /// The perception buffer, which scheduled jobs may still be writing until m_boxcastJobHandle completes;
        /// or a default (uncreated) array if the buffers have not been allocated yet.
        /// </returns>
        public NativeArray<float> Perceive()
        {
            if (!m_PerceptionBuffer.IsCreated || !m_Hits.IsCreated || !m_BoxcastCommands.IsCreated)
            {
                return new NativeArray<float>();
            }

            Profiler.BeginSample("GridSensor.Perceive");
            {
                // The previous chain writes the same buffers; it must finish before new work is scheduled.
                m_boxcastJobHandle.Complete();

                JobHandle clearHandle;
                if (gridDepthType == GridDepthType.ChannelHot)
                {
                    var clearJob = new ClearBufferOneHotJob
                    {
                        PerceptionBuf = m_PerceptionBuffer,
                        ChannelHotBuf = m_ChannelHotDefaultPerceptionBuffer,
                        ObservationPerCell = ObservationPerCell
                    };
                    clearHandle = clearJob.Schedule(NumberOfObservations, BatchSize(NumberOfObservations));
                }
                else
                {
                    var clearJob = new ClearBufferChannelJob
                    {
                        PerceptionBuf = m_PerceptionBuffer
                    };
                    clearHandle = clearJob.Schedule(NumberOfObservations, BatchSize(NumberOfObservations));
                }

                if (ShowGizmos)
                {
                    var gizmoJob = new ClearCellActivityJob
                    {
                        CellActivity = CellActivity,
                        DebugDefault = DebugDefaultColor
                    };
                    var gizmoHandle = gizmoJob.Schedule(NumCells, BatchSize(NumCells));
                    clearHandle = JobHandle.CombineDependencies(clearHandle, gizmoHandle);
                }

                var t = transform;
                var halfCellScale = new Vector3(CellScaleX / 2f, CellScaleY / 2f, CellScaleZ / 2f);

                JobHandle createBoxHandle;
                if (RotateToAgent)
                {
                    var createBoxcasts = new CreateBoxcastRotateBatch
                    {
                        Commands = m_BoxcastCommands,
                        halfScale = halfCellScale,
                        Mat = t.localToWorldMatrix,
                        CellPoints = CellPoints,
                        ObserveMask = ObserveMask,
                        Rotation = t.rotation
                    };
                    createBoxHandle = createBoxcasts.Schedule(NumCells, BatchSize(NumCells), clearHandle);
                }
                else
                {
                    var createBoxcasts = new CreateBoxcastBatch
                    {
                        Commands = m_BoxcastCommands,
                        halfScale = halfCellScale,
                        Pos = t.position,
                        CellPoints = CellPoints,
                        ObserveMask = ObserveMask
                    };
                    createBoxHandle = createBoxcasts.Schedule(NumCells, BatchSize(NumCells), clearHandle);
                }

                m_boxcastJobHandle = BoxcastCommand.ScheduleBatch(m_BoxcastCommands, m_Hits, BatchSize(NumCells), createBoxHandle);
            }
            Profiler.EndSample();
            return m_PerceptionBuffer;
        }

        /// <summary>
        /// Overridable hook for parsing the colliders found within a cell. The base implementation does nothing:
        /// detection in this class is driven by the boxcast results decoded in <see cref="UpdateBufferFromJob"/>.
        /// </summary>
        /// <param name="foundColliders">Array of the colliders found within the cell</param>
        /// <param name="numFound">Number of colliders found.</param>
        /// <param name="cellIndex">The index of the cell</param>
        /// <param name="cellCenter">The center position of the cell</param>
        protected virtual void ParseColliders(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter)
        {
        }

        /// <summary>
        /// GetObjectData - returns an array of values that represent the game object.
        /// This is one of the few methods that one may need to override to get their required functionality.
        /// For instance, if one wants specific information about the current gameobject, they can use this method
        /// to extract it and then return it in an array format.
        /// The base implementation returns a reused buffer whose first entry is <paramref name="typeIndex"/> and
        /// whose remaining entries are 0.
        /// </summary>
        /// <returns>
        /// A float[] containing the data that holds the representative information of the passed in gameObject
        /// </returns>
        /// <param name="currentColliderGo">The game object that was found colliding with a certain cell</param>
        /// <param name="typeIndex">The index of the type (tag) of the gameObject.
        ///           (e.g., if this GameObject had the 3rd tag out of 4, type_index would be 2.0f)</param>
        /// <param name="normalizedDistance">A float between 0 and 1 describing the ratio of
        ///            the distance currentColliderGo is compared to the edge of the gridsensor</param>
        /// <example>
        ///   Here is an example of extending GetObjectData to include information about a potential Rigidbody:
        ///   <code>
        ///     protected override float[] GetObjectData(GameObject currentColliderGo,
        ///         float type_index, float normalized_distance)
        ///     {
        ///         float[] channelValues = new float[ChannelDepth.Length]; // ChannelDepth.Length = 4 in this example
        ///         channelValues[0] = type_index;
        ///         Rigidbody goRb = currentColliderGo.GetComponent&lt;Rigidbody&gt;();
        ///         if (goRb != null)
        ///         {
        ///             channelValues[1] = goRb.velocity.x;
        ///             channelValues[2] = goRb.velocity.y;
        ///             channelValues[3] = goRb.velocity.z;
        ///         }
        ///         return channelValues;
        ///     }
        ///   </code>
        /// </example>
        protected virtual float[] GetObjectData(GameObject currentColliderGo, float typeIndex, float normalizedDistance)
        {
            Array.Clear(m_ChannelBuffer, 0, m_ChannelBuffer.Length);
            m_ChannelBuffer[0] = typeIndex;
            return m_ChannelBuffer;
        }

        /// <summary>
        /// Runs basic validation assertions to check that the values can be normalized: each value must lie in
        /// [0, ChannelDepth[j]].
        /// </summary>
        /// <param name="channelValues">The values to be validated</param>
        /// <param name="currentColliderGo">The gameobject used for better error messages</param>
        /// <exception cref="UnityAgentsException">Thrown when a value is negative or exceeds its channel depth.</exception>
        protected virtual void ValidateValues(float[] channelValues, GameObject currentColliderGo)
        {
            for (var j = 0; j < channelValues.Length; j++)
            {
                if (channelValues[j] < 0)
                    throw new UnityAgentsException("Expected ChannelValue[" + j + "] for " + currentColliderGo.name + " to be non-negative, was " + channelValues[j]);

                if (channelValues[j] > ChannelDepth[j])
                    throw new UnityAgentsException("Expected ChannelValue[" + j + "] for " + currentColliderGo.name + " to be less than ChannelDepth[" + j + "] (" + ChannelDepth[j] + "), was " + channelValues[j]);
            }
        }

        /// <summary>
        /// LoadObjectData - GetObjectData is called to extract the data from the GameObject,
        /// then the data is transformed based on the GridDepthType of the gridsensor and written into the
        /// cell's slice of the perception buffer.
        /// </summary>
        /// <param name="currentColliderGo">The game object that was found colliding with a certain cell</param>
        /// <param name="cellIndex">The index of the current cell</param>
        /// <param name="detectableIndex">Index into the DetectableObjects array.</param>
        /// <param name="normalizedDistance">A float between 0 and 1 describing the ratio of
        ///  the distance currentColliderGo is compared to the edge of the gridsensor</param>
        protected virtual void LoadObjectData(GameObject currentColliderGo, int cellIndex, int detectableIndex, float normalizedDistance)
        {
            Profiler.BeginSample("GridSensor.LoadObjectData");
            var offset = cellIndex * ObservationPerCell;
            for (var ii = 0; ii < ObservationPerCell; ii++)
            {
                m_PerceptionBuffer[offset + ii] = 0f;
            }

            // TODO: Create the array already then set the values using "out" in GetObjectData
            // Using i+1 as the type index as "0" represents "empty"
            var channelValues = GetObjectData(currentColliderGo, (float)detectableIndex + 1, normalizedDistance);

            ValidateValues(channelValues, currentColliderGo);
            if (ShowGizmos)
            {
                // Fall back to white when no color is configured for this tag.
                var debugRayColor = Color.white;
                if (DebugColors != null && detectableIndex >= 0 && detectableIndex < DebugColors.Length)
                {
                    debugRayColor = DebugColors[detectableIndex];
                }

                CellActivity[cellIndex] = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
            }

            switch (gridDepthType)
            {
                case GridDepthType.Channel:
                {
                    // The observations are "channel based" so each grid is WxHxC where C is the number of channels.
                    // Each channel value is divided by its channelDepth (a depth of 1 leaves it unchanged), and
                    // the channels are then stored consecutively in PerceptionBuffer.
                    // NOTE: This is the only grid type that uses floating point values.
                    // For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams:
                    // channelValues = {2, 1}
                    // ObservationPerCell = channelValues.Length
                    // channelValues = {2f/5f, 1f/3f} = {.4, .33..}
                    for (var j = 0; j < channelValues.Length; j++)
                    {
                        channelValues[j] /= ChannelDepth[j];
                    }

                    for (var ii = 0; ii < ObservationPerCell; ii++)
                    {
                        m_PerceptionBuffer[offset + ii] = channelValues[ii];
                    }
                    break;
                }

                case GridDepthType.ChannelHot:
                {
                    // The observations are "channel hot" so each grid is WxHxD where D is the sum of all of the channel depths.
                    // The opposite of the "channel based" case, the channel values are represented as one hot vector per
                    // channel (when its depth is greater than 1) and then concatenated together; depth-1 channels are
                    // copied through as-is.
                    // For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams,
                    // channelValues = {2, 1}
                    // ChannelDepth = {5, 3}, ChannelOffsets = {0, 5}
                    // ObservationPerCell = 5 + 3 = 8
                    // channelHotVals = {0, 0, 1, 0, 0, 0, 1, 0}
                    for (var j = 0; j < channelValues.Length; j++)
                    {
                        if (ChannelDepth[j] > 1)
                        {
                            m_PerceptionBuffer[offset + (int)channelValues[j] + ChannelOffsets[j]] = 1f;
                        }
                        else
                        {
                            m_PerceptionBuffer[offset + ChannelOffsets[j]] = channelValues[j];
                        }
                    }
                    break;
                }
            }

            Profiler.EndSample();
        }

        /// <summary>Converts the index of the cell to the 3D point (y is zero).</summary>
        /// <returns>Vector3 of the position of the center of the cell</returns>
        /// <param name="cell">The index of the cell</param>
        /// <param name="shouldTransformPoint">Whether to transform the point by this component's transform</param>
        /// <remarks>
        /// NOTE(review): the row/column split uses GridNumSideZ while PointToCell and Write use GridNumSideX as
        /// the row length; the two agree only for square grids. Confirm before relying on non-square grids.
        /// </remarks>
        protected Vector3 CellToPoint(int cell, bool shouldTransformPoint = true)
        {
            var x = (cell % GridNumSideZ - OffsetGridNumSide) * CellScaleX;
            var z = (cell / GridNumSideZ - OffsetGridNumSide) * CellScaleZ - DiffNumSideZX;

            if (shouldTransformPoint)
                return transform.TransformPoint(new Vector3(x, 0, z));
            return new Vector3(x, 0, z);
        }

        /// <summary>Finds the cell in which the given global point falls.</summary>
        /// <returns>
        /// The index of the cell in which the global point falls or -1 if the point does not fall into a cell
        /// </returns>
        /// <param name="globalPoint">The 3D point in global space</param>
        public int PointToCell(Vector3 globalPoint)
        {
            var point = transform.InverseTransformPoint(globalPoint);

            if (point.x < -HalfOfGridX || point.x > HalfOfGridX || point.z < -HalfOfGridZ || point.z > HalfOfGridZ)
                return -1;

            var x = point.x + HalfOfGridX;
            var z = point.z + HalfOfGridZ;

            // A point exactly on the far edge maps to index GridNumSide*, one past the last cell; clamp it back.
            var cellX = Mathf.Min((int)Mathf.Floor(x * PointToCellScalingX), GridNumSideX - 1);
            var cellZ = Mathf.Min((int)Mathf.Floor(z * PointToCellScalingZ), GridNumSideZ - 1);
            return GridNumSideX * cellZ + cellX;
        }

        void OnDrawGizmos()
        {
            if (ShowGizmos)
            {
                // Outside play mode nothing else initializes the sensor, so do it on every draw.
                if (Application.isEditor && !Application.isPlaying)
                    Start();

                Perceive();
                UpdateBufferFromJob();

                var scale = new Vector3(CellScaleX, CellScaleY, CellScaleZ);
                var offset = new Vector3(0, 5.0f, 0);
                var oldGizmoMatrix = Gizmos.matrix;
                for (var i = 0; i < NumCells; i++)
                {
                    Matrix4x4 cubeTransform;
                    if (RotateToAgent)
                    {
                        cubeTransform = Matrix4x4.TRS(CellToPoint(i) + offset, transform.rotation, scale);
                    }
                    else
                    {
                        cubeTransform = Matrix4x4.TRS(CellToPoint(i, false) + transform.position + offset, Quaternion.identity, scale);
                    }
                    Gizmos.matrix = oldGizmoMatrix * cubeTransform;
                    Gizmos.color = CellActivity[i];
                    Gizmos.DrawCube(Vector3.zero, Vector3.one);
                }

                Gizmos.matrix = oldGizmoMatrix;

                if (Application.isEditor && !Application.isPlaying)
                    DestroyImmediate(m_perceptionTexture2D);
            }
        }

        /// <inheritdoc/>
        void ISensor.Update()
        {
            Profiler.BeginSample("GridSensor.Update");
            {
                Perceive();
            }
            Profiler.EndSample();
        }

        /// <summary>Gets the observation shape.</summary>
        /// <returns>int[] of the observation shape</returns>
        public override int[] GetObservationShape()
        {
            m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
            return m_Shape;
        }

        /// <inheritdoc/>
        public int Write(ObservationWriter writer)
        {
            UpdateBufferFromJob();
            var index = 0;
            Profiler.BeginSample("GridSensor.WriteToTensor");
            {
                // The buffer is consumed in storage order; its first row of cells goes to the top row (h = GridNumSideZ - 1).
                for (var h = GridNumSideZ - 1; h >= 0; h--) // height
                {
                    for (var w = 0; w < GridNumSideX; w++) // width
                    {
                        for (var d = 0; d < ObservationPerCell; d++) // depth
                        {
                            writer[h, w, d] = m_PerceptionBuffer[index];
                            index++;
                        }
                    }
                }
                Profiler.EndSample();
            }
            return index;
        }

        /// <summary>
        /// Completes the job chain scheduled by <see cref="Perceive"/> and, for every cell whose boxcast hit a
        /// collider tagged with one of <see cref="DetectableObjects"/>, loads that object's data into the cell
        /// (first matching tag wins). Does nothing beyond completing the jobs if the buffers are not allocated.
        /// NOTE(review): normalizedDistance is always passed as 0 here.
        /// </summary>
        internal void UpdateBufferFromJob()
        {
            Profiler.BeginSample("UpdateBufferFromJob");
            {
                m_boxcastJobHandle.Complete();
                if (m_Hits.IsCreated && m_PerceptionBuffer.IsCreated)
                {
                    for (var cellIndex = 0; cellIndex < NumCells; cellIndex++)
                    {
                        var c = m_Hits[cellIndex].collider;
                        if (ReferenceEquals(c, null))
                        {
                            continue;
                        }

                        for (var i = 0; i < DetectableObjects.Length; i++)
                        {
                            if (c.CompareTag(DetectableObjects[i]))
                            {
                                LoadObjectData(c.gameObject, cellIndex, i, 0);
                                break;
                            }
                        }
                    }
                }
            }
            Profiler.EndSample();
        }
    }
}