浏览代码

Move GridSensor into main package (#5256)

* move OneHotGridSensor into main package

* changelog and migration guide

* remove old doc

* check if physics module presents
/check-for-ModelOverriders
GitHub 4 年前
当前提交
2a9c8f0d
共有 47 个文件被更改,包括 928 次插入3439 次删除
  1. 60
      Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockAgentGridCollab.prefab
  2. 141
      Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockCollabAreaGrid.prefab
  3. 1
      com.unity.ml-agents.extensions/Documentation~/com.unity.ml-agents.extensions.md
  4. 6
      com.unity.ml-agents.extensions/Runtime/Sensors/CountingGridSensor.cs
  5. 5
      com.unity.ml-agents/CHANGELOG.md
  6. 2
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
  7. 4
      com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
  8. 28
      docs/Migrating.md
  9. 2
      com.unity.ml-agents/Editor/GridSensorComponentEditor.cs.meta
  10. 28
      com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs
  11. 9
      com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs
  12. 8
      com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs
  13. 2
      com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs
  14. 37
      com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs
  15. 24
      com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs
  16. 143
      com.unity.ml-agents.extensions/Tests/Runtime/Sensors/CountingGridSensorTests.cs
  17. 11
      com.unity.ml-agents.extensions/Tests/Runtime/Sensors/CountingGridSensorTests.cs.meta
  18. 108
      com.unity.ml-agents/Editor/GridSensorComponentEditor.cs
  19. 267
      com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs
  20. 293
      com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs
  21. 1001
      com.unity.ml-agents.extensions/Documentation~/images/gridobs-vs-vectorobs.gif
  22. 20
      com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example-camera.png
  23. 94
      com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example-gridsensor.png
  24. 67
      com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example-raycast.png
  25. 79
      com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example.png
  26. 1001
      com.unity.ml-agents.extensions/Documentation~/images/gridsensor-debug.png
  27. 230
      com.unity.ml-agents.extensions/Documentation~/Grid-Sensor.md
  28. 106
      com.unity.ml-agents.extensions/Editor/GridSensorComponentEditor.cs
  29. 254
      com.unity.ml-agents.extensions/Runtime/Sensors/BoxOverlapChecker.cs
  30. 328
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensorComponent.cs
  31. 8
      com.unity.ml-agents.extensions/Tests/Editor/GridSensors.meta
  32. 0
      /com.unity.ml-agents/Editor/GridSensorComponentEditor.cs.meta
  33. 0
      /com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs.meta
  34. 0
      /com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs
  35. 0
      /com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs
  36. 0
      /com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs.meta
  37. 0
      /com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs.meta
  38. 0
      /com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs.meta
  39. 0
      /com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs
  40. 0
      /com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs
  41. 0
      /com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs
  42. 0
      /com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs
  43. 0
      /com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs.meta
  44. 0
      /com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs.meta
  45. 0
      /com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs.meta
  46. 0
      /com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs.meta

60
Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockAgentGridCollab.prefab


serializedVersion: 6
m_Component:
- component: {fileID: 2709359580712052713}
- component: {fileID: 2709359580712052712}
- component: {fileID: 1548337883655231979}
m_Layer: 0
m_Name: GridSensor
m_TagString: Untagged

m_Father: {fileID: 2708762399863795223}
m_RootOrder: 1
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!114 &2709359580712052712
--- !u!114 &1548337883655231979
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}

m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 801669c0cdece6b40b2e741ad0b119ac, type: 3}
m_Script: {fileID: 11500000, guid: 2a501962d056745d1a30e99146ee39fe, type: 3}
Name: AgentGrid
CellScaleX: 1
CellScaleZ: 1
GridNumSideX: 20
GridNumSideZ: 20
CellScaleY: 0.5
RotateToAgent: 1
ChannelDepth: 07000000
DetectableObjects:
m_SensorName: GridSensor
m_CellScale: {x: 1, y: 0.5, z: 1}
m_GridSize: {x: 20, y: 1, z: 20}
m_RotateWithAgent: 1
m_DetectableTags:
- wall
- agent
- goal

ObserveMask:
m_ColliderMask:
gridDepthType: 1
rootReference: {fileID: 2710286047221272849}
MaxColliderBufferSize: 500
InitialColliderBufferSize: 16
ObservationPerCell: 7
NumberOfObservations: 2800
ChannelOffsets: 00000000
DebugColors:
m_MaxColliderBufferSize: 500
m_InitialColliderBufferSize: 16
m_DebugColors:
- {r: 0, g: 0.51824737, b: 1, a: 1}
- {r: 0.4680206, g: 0.7058824, b: 0.35155708, a: 1}
- {r: 1, g: 0.99570733, b: 0.984, a: 1}
- {r: 0.4811321, g: 0.4811321, b: 0.4811321, a: 1}
- {r: 0.3584906, g: 0.3584906, b: 0.3584906, a: 0}
GizmoYOffset: 0
ShowGizmos: 0
CompressionType: 1
- {r: 0, g: 0.5176471, b: 1, a: 0}
- {r: 0.46666667, g: 0.7058824, b: 0.3529412, a: 0}
- {r: 1, g: 0.99607843, b: 0.9843137, a: 0}
- {r: 0.48235294, g: 0.48235294, b: 0.48235294, a: 0}
- {r: 0.35686275, g: 0.35686275, b: 0.35686275, a: 0}
m_GizmoYOffset: 0
m_ShowGizmos: 0
m_CompressionType: 1
m_ObservationStacks: 1
--- !u!1 &2709573194145405553
GameObject:
m_ObjectHideFlags: 0

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

141
Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockCollabAreaGrid.prefab


onTriggerExitEvent:
m_PersistentCalls:
m_Calls: []
--- !u!114 &1809664679221531284
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 8191066182862526894}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 2a501962d056745d1a30e99146ee39fe, type: 3}
m_Name:
m_EditorClassIdentifier:
m_SensorName: GridSensor
m_CellScale: {x: 1, y: 0.01, z: 1}
m_GridSize: {x: 20, y: 1, z: 20}
m_RotateWithAgent: 1
m_DetectableTags:
- wall
- agent
- goal
- blockSmall
- blockLarge
- blockVeryLarge
m_ColliderMask:
serializedVersion: 2
m_Bits: 1
m_MaxColliderBufferSize: 500
m_InitialColliderBufferSize: 16
m_DebugColors:
- {r: 0, g: 0, b: 0, a: 0}
- {r: 0, g: 0.5176471, b: 1, a: 0}
- {r: 0.46666667, g: 0.7058824, b: 0.3529412, a: 0}
- {r: 1, g: 0.99607843, b: 0.9843137, a: 0}
- {r: 0.48235294, g: 0.48235294, b: 0.48235294, a: 0}
- {r: 0.35686275, g: 0.35686275, b: 0.35686275, a: 0}
m_GizmoYOffset: 0
m_ShowGizmos: 0
m_CompressionType: 1
m_ObservationStacks: 1
m_UseOneHotTag: 1
m_CountColliders: 0
--- !u!1 &8191066182918326564
GameObject:
m_ObjectHideFlags: 0

UseRandomAgentPosition: 1
UseRandomBlockRotation: 1
UseRandomBlockPosition: 1
--- !u!114 &4609315540733531199
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 8696048509000480032}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 2a501962d056745d1a30e99146ee39fe, type: 3}
m_Name:
m_EditorClassIdentifier:
m_SensorName: GridSensor
m_CellScale: {x: 1, y: 0.01, z: 1}
m_GridSize: {x: 20, y: 1, z: 20}
m_RotateWithAgent: 1
m_DetectableTags:
- wall
- agent
- goal
- blockSmall
- blockLarge
- blockVeryLarge
m_ColliderMask:
serializedVersion: 2
m_Bits: 1
m_MaxColliderBufferSize: 500
m_InitialColliderBufferSize: 16
m_DebugColors:
- {r: 0, g: 0, b: 0, a: 0}
- {r: 0, g: 0.5176471, b: 1, a: 0}
- {r: 0.46666667, g: 0.7058824, b: 0.3529412, a: 0}
- {r: 1, g: 0.99607843, b: 0.9843137, a: 0}
- {r: 0.48235294, g: 0.48235294, b: 0.48235294, a: 0}
- {r: 0.35686275, g: 0.35686275, b: 0.35686275, a: 0}
m_GizmoYOffset: 0
m_ShowGizmos: 0
m_CompressionType: 1
m_ObservationStacks: 1
m_UseOneHotTag: 1
m_CountColliders: 0
--- !u!1 &8821353056066081524
GameObject:
m_ObjectHideFlags: 0

m_FallbackScreenDPI: 96
m_DefaultSpriteDPI: 96
m_DynamicPixelsPerUnit: 1
--- !u!114 &6319243058783963332
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 9116780590443581137}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 2a501962d056745d1a30e99146ee39fe, type: 3}
m_Name:
m_EditorClassIdentifier:
m_SensorName: GridSensor
m_CellScale: {x: 1, y: 0.01, z: 1}
m_GridSize: {x: 20, y: 1, z: 20}
m_RotateWithAgent: 1
m_DetectableTags:
- wall
- agent
- goal
- blockSmall
- blockLarge
- blockVeryLarge
m_ColliderMask:
serializedVersion: 2
m_Bits: 1
m_MaxColliderBufferSize: 500
m_InitialColliderBufferSize: 16
m_DebugColors:
- {r: 0, g: 0, b: 0, a: 0}
- {r: 0, g: 0.5176471, b: 1, a: 0}
- {r: 0.46666667, g: 0.7058824, b: 0.3529412, a: 0}
- {r: 1, g: 0.99607843, b: 0.9843137, a: 0}
- {r: 0.48235294, g: 0.48235294, b: 0.48235294, a: 0}
- {r: 0.35686275, g: 0.35686275, b: 0.35686275, a: 0}
m_GizmoYOffset: 0
m_ShowGizmos: 0
m_CompressionType: 1
m_ObservationStacks: 1
m_UseOneHotTag: 1
m_CountColliders: 0
--- !u!1001 &6067781793364901444
PrefabInstance:
m_ObjectHideFlags: 0

m_RemovedComponents:
- {fileID: 2709359580712052712, guid: ac01d0f42c5e1463e943632a60d99967, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: ac01d0f42c5e1463e943632a60d99967, type: 3}
--- !u!1 &8191066182862526894 stripped
GameObject:
m_CorrespondingSourceObject: {fileID: 2709359580712052714, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
m_PrefabInstance: {fileID: 6067781793364901444}
m_PrefabAsset: {fileID: 0}
--- !u!114 &8190299122290044756 stripped
MonoBehaviour:
m_CorrespondingSourceObject: {fileID: 2710286047221272848, guid: ac01d0f42c5e1463e943632a60d99967,

m_RemovedComponents:
- {fileID: 2709359580712052712, guid: ac01d0f42c5e1463e943632a60d99967, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: ac01d0f42c5e1463e943632a60d99967, type: 3}
--- !u!1 &9116780590443581137 stripped
GameObject:
m_CorrespondingSourceObject: {fileID: 2709359580712052714, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
m_PrefabInstance: {fileID: 6565363751102736699}
m_PrefabAsset: {fileID: 0}
--- !u!114 &9115291448867436587 stripped
MonoBehaviour:
m_CorrespondingSourceObject: {fileID: 2710286047221272848, guid: ac01d0f42c5e1463e943632a60d99967,

m_RemovedComponents:
- {fileID: 2709359580712052712, guid: ac01d0f42c5e1463e943632a60d99967, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: ac01d0f42c5e1463e943632a60d99967, type: 3}
--- !u!1 &8696048509000480032 stripped
GameObject:
m_CorrespondingSourceObject: {fileID: 2709359580712052714, guid: ac01d0f42c5e1463e943632a60d99967,
type: 3}
m_PrefabInstance: {fileID: 6716844123244810954}
m_PrefabAsset: {fileID: 0}
--- !u!114 &8695281997955662810 stripped
MonoBehaviour:
m_CorrespondingSourceObject: {fileID: 2710286047221272848, guid: ac01d0f42c5e1463e943632a60d99967,

1
com.unity.ml-agents.extensions/Documentation~/com.unity.ml-agents.extensions.md


| _Tests_ | Contains the unit tests for the package. |
The Runtime directory currently contains these features:
* [Grid-based sensor](Grid-Sensor.md)
* Physics-based sensors
* [Input System Package Integration](InputActuatorComponent.md)

6
com.unity.ml-agents.extensions/Runtime/Sensors/CountingGridSensor.cs


/// </summary>
/// <param name="name">The sensor name</param>
/// <param name="cellScale">The scale of each cell in the grid</param>
/// <param name="gridNum">Number of cells on each side of the grid</param>
/// <param name="gridSize">Number of cells on each side of the grid</param>
Vector3Int gridNum,
Vector3Int gridSize,
) : base(name, cellScale, gridNum, detectableTags, compression)
) : base(name, cellScale, gridSize, detectableTags, compression)
{
CompressionType = SensorCompressionType.None;
}

5
com.unity.ml-agents/CHANGELOG.md


- `ISensor.GetCompressionType()` was removed, and `GetCompressionSpec()` was added. The `ISparseChannelSensor`
interface was removed. (#5164)
- The abstract method `SensorComponent.GetObservationShape()` was no longer being called, so it has been removed. (#5172)
- `SensorComponent.CreateSensor()` was replaced with `SensorComponent.CreateSensor()`, which returns an `ISensor[]`. (#5181)
- `SensorComponent.CreateSensor()` was replaced with `SensorComponent.CreateSensors()`, which returns an `ISensor[]`. (#5181)
- `Match3Sensor` was refactored to produce cell and special type observations separately, and `Match3SensorComponent` now
produces two `Match3Sensor`s (unless there are no special types). Previously trained models will have different observation
sizes and will need to be retrained. (#5181)

- `GridSensor` has been refactored and moved to main package, with changes to both sensor interfaces and behaviors.
Exsisting GridSensor created by extension package will not work in newer version. Previously trained models will
need to be retrained. Please see the Migration Guide for more details. (#5256)
### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)

2
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


// For each ray, write the information to the observation buffer
for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
{
m_RayPerceptionOutput.RayOutputs[rayIndex].ToFloatArray(numDetectableTags, rayIndex, m_Observations);
m_RayPerceptionOutput.RayOutputs?[rayIndex].ToFloatArray(numDetectableTags, rayIndex, m_Observations);
}
// Finally, add the observations to the ObservationWriter

4
com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs


// Can't actually create an Agent with InferenceOnly and no model, so change back
behaviorParams.BehaviorType = BehaviorType.Default;
#if MLA_UNITY_PHSYICS_MODULE
#if MLA_UNITY_PHYSICS_MODULE
var sensorComponent = gameObject.AddComponent<RayPerceptionSensorComponent3D>();
sensorComponent.SensorName = "ray3d";
sensorComponent.DetectableTags = new List<string> { "Player", "Respawn" };

decisionRequester.DecisionPeriod = 2;
decisionRequester.TakeActionsBetweenDecisions = true;
#if MLA_UNITY_PHSYICS_MODULE
#if MLA_UNITY_PHYSICS_MODULE
// Initialization should set up the sensors
Assert.IsNotNull(sensorComponent.RaySensor);
#endif

28
docs/Migrating.md


current `BoardSize`. The values returned by `GetCurrentBoardSize()` must be less than or equal to the corresponding
values from `GetMaxBoardSize()`.
### GridSensor changes
The sensor configuration has changed:
* The sensor implementation has been refactored and exsisting GridSensor created from extension package
will not work in newer version. Some errors might show up when loading the old sensor in the scene.
You'll need to remove the old sensor and create a new GridSensor.
* These parameters names have changed but still refer to the same concept in the sensor: `GridNumSide` -> `GridSize`,
`RotateToAgent` -> `RotateWithAgent`, `ObserveMask` -> `ColliderMask`, `DetectableObjects` -> `DetectableTags`
* `RootReference` is removed and the sensor component's GameObject will always be ignored for hit results.
* `DepthType` (`ChanelBase`/`ChannelHot`) option and `ChannelDepth` are removed. Now the default is
one-hot encoding for detected tag. If you were using original GridSensor without overriding any method,
switching to new GridSensor will produce similar effect for training although the actual observations
will be slightly different.
For creating your GridSensor implementation with custom data:
* To create custom GridSensor, derive from `GridSensorBase` instead of `GridSensor`. Besides overriding
`GetObjectData()`, you will also need to consider override `GetCellObservationSize()`, `IsDataNormalized()`
and `GetProcessCollidersMethod()` according to the data you collect. Also you'll need to override
`GridSensorComponent.GetGridSensors()` and return your custom GridSensor.
* The input argument `tagIndex` in `GetObjectData()` has changed from 1-indexed to 0-indexed and the
data type changed from `float` to `int`. The index of first detectable tag will be 0 instead of 1.
`normalizedDistance` was removed from input.
* The observation data should be written to the input `dataBuffer` instead of creating and returning a new array.
* Removed the constraint of all data required to be normalized. You should specify it in `IsDataNormalized()`.
Sensors with non-normalized data cannot use PNG compression type.
* The sensor will not further encode the data recieved from `GetObjectData()` anymore. The values
recieved from `GetObjectData()` will be the observation sent to the trainer.
## Migrating to Release 13
### Implementing IHeuristic in your IActuator implementations
- If you have any custom actuators, you can now implement the `IHeuristicProvider` interface to have your actuator

2
com.unity.ml-agents/Editor/GridSensorComponentEditor.cs.meta


fileFormatVersion: 2
guid: 62dc58d0ddf584affa1f269e9c5791c2
guid: 584686b36fcb2435c8be47d70c332ed0
MonoImporter:
externalObjects: {}
serializedVersion: 2

28
com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs


using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using Unity.MLAgents.Sensors;
using Object = UnityEngine.Object;
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]
namespace Unity.MLAgents.Extensions.Sensors
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// The way the GridSensor process detected colliders in a cell.

/// </summary>
/// <param name="name">The sensor name</param>
/// <param name="cellScale">The scale of each cell in the grid</param>
/// <param name="gridNum">Number of cells on each side of the grid</param>
/// <param name="gridSize">Number of cells on each side of the grid</param>
Vector3Int gridNum,
Vector3Int gridSize,
string[] detectableTags,
SensorCompressionType compression
)

m_GridSize = gridNum;
m_GridSize = gridSize;
m_DetectableTags = detectableTags;
CompressionType = compression;

{
if (!ReferenceEquals(null, m_PerceptionTexture))
{
DestroyTexture(m_PerceptionTexture);
Utilities.DestroyTexture(m_PerceptionTexture);
}
}
static void DestroyTexture(Texture2D texture)
{
if (Application.isEditor)
{
// Edit Mode tests complain if we use Destroy()
// TODO move to extension methods for UnityEngine.Object?
Object.DestroyImmediate(texture);
}
else
{
Object.Destroy(texture);
}
}
}

9
com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs


using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// Grid-based sensor with one-hot observations.

/// </summary>
/// <param name="name">The sensor name</param>
/// <param name="cellScale">The scale of each cell in the grid</param>
/// <param name="gridNum">Number of cells on each side of the grid</param>
/// <param name="gridSize">Number of cells on each side of the grid</param>
Vector3Int gridNum,
Vector3Int gridSize,
) : base(name, cellScale, gridNum, detectableTags, compression)
) : base(name, cellScale, gridSize, detectableTags, compression)
{
}

8
com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs


#if MLA_UNITY_PHYSICS_MODULE
using Unity.MLAgents.Extensions.Sensors;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Tests.GridSensors
namespace Unity.MLAgents.Tests
{
internal class TestBoxOverlapChecker : BoxOverlapChecker
{

var testGo = new GameObject("test");
testGo.transform.position = Vector3.zero;
var gridSensorComponent = testGo.AddComponent<SimpleTestGridSensorComponent>();
gridSensorComponent.SetComponentParameters(useGridSensorBase: true, useOneHotTag: true, countColliders: true);
gridSensorComponent.SetComponentParameters(useGridSensorBase: true, useTestingGridSensor: true);
var sensors = gridSensorComponent.CreateSensors();
int numChecker = 0;
foreach (var sensor in sensors)

}
}
}
#endif

2
com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs


using System;
using System.Linq;
namespace Unity.MLAgents.Extensions.Tests.GridSensors
namespace Unity.MLAgents.Tests
{
public static class GridObsTestUtils
{

37
com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs


#if MLA_UNITY_PHYSICS_MODULE
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.GridSensors
namespace Unity.MLAgents.Tests
{
public class GridSensorTests
{

}
[Test]
public void TestCountingSensor()
{
testGo.tag = k_Tag2;
string[] tags = { k_Tag1, k_Tag2 };
gridSensorComponent.SetComponentParameters(tags, countColliders: true);
var gridSensor = (CountingGridSensor)gridSensorComponent.CreateSensors()[0];
Assert.AreEqual(gridSensor.PerceptionBuffer.Length, 10 * 10 * 2);
gridSensor.Update();
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 };
float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new float[] { 1, 0 }, 4);
float[] expectedDefault = new float[] { 0, 0 };
GridObsTestUtils.AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndicies, expectedSubarrays, expectedDefault);
var boxGo2 = new GameObject("block");
boxGo2.tag = k_Tag1;
boxGo2.transform.position = new Vector3(3.1f, 0f, 3f);
boxGo2.AddComponent<BoxCollider>();
gridSensor.Update();
subarrayIndicies = new int[] { 77, 78, 87, 88 };
expectedSubarrays = GridObsTestUtils.DuplicateArray(new float[] { 2, 0 }, 4);
expectedDefault = new float[] { 0, 0 };
GridObsTestUtils.AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndicies, expectedSubarrays, expectedDefault);
Object.DestroyImmediate(boxGo2);
}
[Test]
public void TestCustomSensorInvalidData()
{
testGo.tag = k_Tag2;

{
testGo.tag = k_Tag2;
string[] tags = { k_Tag1, k_Tag2 };
gridSensorComponent.SetComponentParameters(tags, useOneHotTag: true, countColliders: true, useTestingGridSensor: true);
gridSensorComponent.SetComponentParameters(tags, useOneHotTag: true, useGridSensorBase: true, useTestingGridSensor: true);
var gridSensors = gridSensorComponent.CreateSensors();
Assert.IsNotNull(((GridSensorBase)gridSensors[0]).m_BoxOverlapChecker);
Assert.IsNull(((GridSensorBase)gridSensors[1]).m_BoxOverlapChecker);

}
}
}
#endif

24
com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs


using System.Linq;
using System.Collections.Generic;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.GridSensors
namespace Unity.MLAgents.Tests
{
public static class TestGridSensorConfig
{

{
return TestGridSensorConfig.ParseAllColliders ? ProcessCollidersMethod.ProcessAllColliders : ProcessCollidersMethod.ProcessClosestColliders;
}
protected override void GetObjectData(GameObject detectedObject, int typeIndex, float[] dataBuffer)
{
for (var i = 0; i < DummyData.Length; i++)

public class SimpleTestGridSensorComponent : GridSensorComponent
{
bool m_UseOneHotTag;
var sensorList = base.GetGridSensors().ToList();
List<GridSensorBase> sensorList = new List<GridSensorBase>();
if (m_UseOneHotTag)
{
var testSensor = new OneHotGridSensor(
SensorName,
CellScale,
GridSize,
DetectableTags,
CompressionType
);
sensorList.Add(testSensor);
}
if (m_UseGridSensorBase)
{
var testSensor = new GridSensorBase(

SensorCompressionType compression = SensorCompressionType.None,
bool rotateWithAgent = false,
bool useOneHotTag = false,
bool countColliders = false,
bool useTestingGridSensor = false,
bool useGridSensorBase = false
)

ColliderMask = colliderMaskInt < 0 ? LayerMask.GetMask("Default") : colliderMaskInt;
RotateWithAgent = rotateWithAgent;
CompressionType = compression;
UseOneHotTag = useOneHotTag;
CountColliders = countColliders;
m_UseOneHotTag = useOneHotTag;
m_UseGridSensorBase = useGridSensorBase;
m_UseTestingGridSensor = useTestingGridSensor;
}

143
com.unity.ml-agents.extensions/Tests/Runtime/Sensors/CountingGridSensorTests.cs


using System;
using System.Collections;
using System.Linq;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Extensions.Sensors;
using Object = UnityEngine.Object;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public class CountingGridSensorTests
{
GameObject testGo;
GameObject boxGo;
TestCountingGridSensorComponent gridSensorComponent;
// Use built-in tags
const string k_Tag1 = "Player";
const string k_Tag2 = "Respawn";
[UnitySetUp]
public IEnumerator SetupScene()
{
testGo = new GameObject("test");
testGo.transform.position = Vector3.zero;
gridSensorComponent = testGo.AddComponent<TestCountingGridSensorComponent>();
boxGo = new GameObject("block");
boxGo.tag = k_Tag1;
boxGo.transform.position = new Vector3(3f, 0f, 3f);
boxGo.AddComponent<BoxCollider>();
yield return null;
}
[TearDown]
public void ClearScene()
{
Object.DestroyImmediate(boxGo);
Object.DestroyImmediate(testGo);
}
public class TestCountingGridSensorComponent : GridSensorComponent
{
public void SetParameters(string[] detectableTags)
{
DetectableTags = detectableTags;
CellScale = new Vector3(1, 0.01f, 1);
GridSize = new Vector3Int(10, 1, 10);
ColliderMask = LayerMask.GetMask("Default");
RotateWithAgent = false;
CompressionType = SensorCompressionType.None;
}
protected override GridSensorBase[] GetGridSensors()
{
return new GridSensorBase[] {
new CountingGridSensor(
"TestSensor",
CellScale,
GridSize,
DetectableTags,
CompressionType) };
}
}
// Copied from GridSensorTests in main package
public static float[][] DuplicateArray(float[] array, int numCopies)
{
float[][] duplicated = new float[numCopies][];
for (int i = 0; i < numCopies; i++)
{
duplicated[i] = array;
}
return duplicated;
}
// Copied from GridSensorTests in main package
public static void AssertSubarraysAtIndex(float[] total, int[] indicies, float[][] expectedArrays, float[] expectedDefaultArray)
{
int totalIndex = 0;
int subIndex = 0;
int subarrayIndex = 0;
int lenOfData = expectedDefaultArray.Length;
int numArrays = total.Length / lenOfData;
for (int i = 0; i < numArrays; i++)
{
totalIndex = i * lenOfData;
if (indicies.Contains(i))
{
subarrayIndex = Array.IndexOf(indicies, i);
for (subIndex = 0; subIndex < lenOfData; subIndex++)
{
Assert.AreEqual(expectedArrays[subarrayIndex][subIndex], total[totalIndex],
"Expected " + expectedArrays[subarrayIndex][subIndex] + " at subarray index " + totalIndex + ", index = " + subIndex + " but was " + total[totalIndex]);
totalIndex++;
}
}
else
{
for (subIndex = 0; subIndex < lenOfData; subIndex++)
{
Assert.AreEqual(expectedDefaultArray[subIndex], total[totalIndex],
"Expected default value " + expectedDefaultArray[subIndex] + " at subarray index " + totalIndex + ", index = " + subIndex + " but was " + total[totalIndex]);
totalIndex++;
}
}
}
}
[Test]
public void TestCountingSensor()
{
string[] tags = { k_Tag1, k_Tag2 };
gridSensorComponent.SetParameters(tags);
var gridSensor = (CountingGridSensor)gridSensorComponent.CreateSensors()[0];
Assert.AreEqual(gridSensor.PerceptionBuffer.Length, 10 * 10 * 2);
gridSensor.Update();
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 };
float[][] expectedSubarrays = DuplicateArray(new float[] { 1, 0 }, 4);
float[] expectedDefault = new float[] { 0, 0 };
AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndicies, expectedSubarrays, expectedDefault);
var boxGo2 = new GameObject("block");
boxGo2.tag = k_Tag1;
boxGo2.transform.position = new Vector3(3.1f, 0f, 3f);
boxGo2.AddComponent<BoxCollider>();
gridSensor.Update();
subarrayIndicies = new int[] { 77, 78, 87, 88 };
expectedSubarrays = DuplicateArray(new float[] { 2, 0 }, 4);
expectedDefault = new float[] { 0, 0 };
AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndicies, expectedSubarrays, expectedDefault);
Object.DestroyImmediate(boxGo2);
}
}
}

11
com.unity.ml-agents.extensions/Tests/Runtime/Sensors/CountingGridSensorTests.cs.meta


fileFormatVersion: 2
guid: 2a1d17f91519347e0a8692e2816b7c8b
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

108
com.unity.ml-agents/Editor/GridSensorComponentEditor.cs


using UnityEditor;
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Editor
{
[CustomEditor(typeof(GridSensorComponent))]
[CanEditMultipleObjects]
internal class GridSensorComponentEditor : UnityEditor.Editor
{
public override void OnInspectorGUI()
{
#if !MLA_UNITY_PHYSICS_MODULE
EditorGUILayout.HelpBox("The Physics Module is not currently present. " +
"Please add it to your project in order to use the GridSensor APIs in the " +
$"{nameof(GridSensorComponent)}", MessageType.Warning);
#endif
var so = serializedObject;
so.Update();
// Drawing the GridSensorComponent
EditorGUI.BeginChangeCheck();
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
// These fields affect the sensor order or observation size,
// So can't be changed at runtime.
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_SensorName)), true);
EditorGUILayout.LabelField("Grid Settings", EditorStyles.boldLabel);
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_CellScale)), true);
// We only supports 2D GridSensor now so lock gridSize.y to 1
var gridSize = so.FindProperty(nameof(GridSensorComponent.m_GridSize));
var gridSize2d = new Vector3Int(gridSize.vector3IntValue.x, 1, gridSize.vector3IntValue.z);
var newGridSize = EditorGUILayout.Vector3IntField("Grid Size", gridSize2d);
gridSize.vector3IntValue = new Vector3Int(newGridSize.x, 1, newGridSize.z);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_RotateWithAgent)), true);
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
// detectable tags
var detectableTags = so.FindProperty(nameof(GridSensorComponent.m_DetectableTags));
var newSize = EditorGUILayout.IntField("Detectable Tags", detectableTags.arraySize);
if (newSize != detectableTags.arraySize)
{
detectableTags.arraySize = newSize;
}
EditorGUI.indentLevel++;
for (var i = 0; i < detectableTags.arraySize; i++)
{
var objectTag = detectableTags.GetArrayElementAtIndex(i);
EditorGUILayout.PropertyField(objectTag, new GUIContent("Tag " + i), true);
}
EditorGUI.indentLevel--;
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ColliderMask)), true);
EditorGUILayout.LabelField("Sensor Settings", EditorStyles.boldLabel);
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ObservationStacks)), true);
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_CompressionType)), true);
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
EditorGUILayout.LabelField("Collider and Buffer", EditorStyles.boldLabel);
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_InitialColliderBufferSize)), true);
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_MaxColliderBufferSize)), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.LabelField("Debug Gizmo", EditorStyles.boldLabel);
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ShowGizmos)), true);
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_GizmoYOffset)), true);
// detectable objects
var debugColors = so.FindProperty(nameof(GridSensorComponent.m_DebugColors));
var detectableObjectSize = so.FindProperty(nameof(GridSensorComponent.m_DetectableTags)).arraySize;
if (detectableObjectSize != debugColors.arraySize)
{
debugColors.arraySize = detectableObjectSize;
}
EditorGUILayout.LabelField("Debug Colors");
EditorGUI.indentLevel++;
for (var i = 0; i < debugColors.arraySize; i++)
{
var debugColor = debugColors.GetArrayElementAtIndex(i);
EditorGUILayout.PropertyField(debugColor, new GUIContent("Tag " + i + " Color"), true);
}
EditorGUI.indentLevel--;
var requireSensorUpdate = EditorGUI.EndChangeCheck();
so.ApplyModifiedProperties();
if (requireSensorUpdate)
{
UpdateSensor();
}
}
void UpdateSensor()
{
var sensorComponent = serializedObject.targetObject as GridSensorComponent;
sensorComponent?.UpdateSensor();
}
}
}

267
com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs


using System;
using UnityEngine;
namespace Unity.MLAgents.Sensors
{
internal class BoxOverlapChecker
{
Vector3 m_CellScale;
Vector3Int m_GridSize;
bool m_RotateWithAgent;
LayerMask m_ColliderMask;
GameObject m_RootReference;
string[] m_DetectableTags;
int m_InitialColliderBufferSize;
int m_MaxColliderBufferSize;
int m_NumCells;
Vector3 m_HalfCellScale;
Vector3 m_CellCenterOffset;
Vector3[] m_CellLocalPositions;
#if MLA_UNITY_PHYSICS_MODULE
Collider[] m_ColliderBuffer;
public event Action<GameObject, int> GridOverlapDetectedAll;
public event Action<GameObject, int> GridOverlapDetectedClosest;
public event Action<GameObject, int> GridOverlapDetectedDebug;
#endif
public BoxOverlapChecker(
Vector3 cellScale,
Vector3Int gridSize,
bool rotateWithAgent,
LayerMask colliderMask,
GameObject rootReference,
string[] detectableTags,
int initialColliderBufferSize,
int maxColliderBufferSize)
{
m_CellScale = cellScale;
m_GridSize = gridSize;
m_RotateWithAgent = rotateWithAgent;
m_ColliderMask = colliderMask;
m_RootReference = rootReference;
m_DetectableTags = detectableTags;
m_InitialColliderBufferSize = initialColliderBufferSize;
m_MaxColliderBufferSize = maxColliderBufferSize;
m_NumCells = gridSize.x * gridSize.z;
m_HalfCellScale = new Vector3(cellScale.x / 2f, cellScale.y, cellScale.z / 2f);
m_CellCenterOffset = new Vector3((gridSize.x - 1f) / 2, 0, (gridSize.z - 1f) / 2);
#if MLA_UNITY_PHYSICS_MODULE
m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_InitialColliderBufferSize)];
#endif
InitCellLocalPositions();
}
public bool RotateWithAgent
{
get { return m_RotateWithAgent; }
set { m_RotateWithAgent = value; }
}
public LayerMask ColliderMask
{
get { return m_ColliderMask; }
set { m_ColliderMask = value; }
}
/// <summary>
/// Initializes the local location of the cells
/// </summary>
void InitCellLocalPositions()
{
m_CellLocalPositions = new Vector3[m_NumCells];
for (int i = 0; i < m_NumCells; i++)
{
m_CellLocalPositions[i] = GetCellLocalPosition(i);
}
}
/// <summary>Converts the index of the cell to the 3D point (y is zero) relative to grid center</summary>
/// <returns>Vector3 of the position of the center of the cell relative to grid center</returns>
/// <param name="cell">The index of the cell</param>
Vector3 GetCellLocalPosition(int cellIndex)
{
float x = (cellIndex / m_GridSize.z - m_CellCenterOffset.x) * m_CellScale.x;
float z = (cellIndex % m_GridSize.z - m_CellCenterOffset.z) * m_CellScale.z;
return new Vector3(x, 0, z);
}
internal Vector3 GetCellGlobalPosition(int cellIndex)
{
if (m_RotateWithAgent)
{
return m_RootReference.transform.TransformPoint(m_CellLocalPositions[cellIndex]);
}
else
{
return m_CellLocalPositions[cellIndex] + m_RootReference.transform.position;
}
}
internal Quaternion GetGridRotation()
{
return m_RotateWithAgent ? m_RootReference.transform.rotation : Quaternion.identity;
}
/// <summary>
/// Perceive the latest grid status. Call OverlapBoxNonAlloc once to detect colliders.
/// Then parse the collider arrays according to all available gridSensor delegates.
/// </summary>
internal void Update()
{
#if MLA_UNITY_PHYSICS_MODULE
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)
{
var cellCenter = GetCellGlobalPosition(cellIndex);
var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, m_HalfCellScale, GetGridRotation());
if (GridOverlapDetectedAll != null)
{
ParseCollidersAll(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedAll);
}
if (GridOverlapDetectedClosest != null)
{
ParseCollidersClosest(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedClosest);
}
}
#endif
}
/// <summary>
/// Same as Update(), but only load data for debug gizmo.
/// </summary>
internal void UpdateGizmo()
{
#if MLA_UNITY_PHYSICS_MODULE
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)
{
var cellCenter = GetCellGlobalPosition(cellIndex);
var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, m_HalfCellScale, GetGridRotation());
ParseCollidersClosest(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedDebug);
}
#endif
}
#if MLA_UNITY_PHYSICS_MODULE
/// <summary>
/// This method attempts to perform the Physics.OverlapBoxNonAlloc and will double the size of the Collider buffer
/// if the number of Colliders in the buffer after the call is equal to the length of the buffer.
/// </summary>
/// <param name="cellCenter"></param>
/// <param name="halfCellScale"></param>
/// <param name="rotation"></param>
/// <returns></returns>
int BufferResizingOverlapBoxNonAlloc(Vector3 cellCenter, Vector3 halfCellScale, Quaternion rotation)
{
int numFound;
// Since we can only get a fixed number of results, requery
// until we're sure we can hold them all (or until we hit the max size).
while (true)
{
numFound = Physics.OverlapBoxNonAlloc(cellCenter, halfCellScale, m_ColliderBuffer, rotation, m_ColliderMask);
if (numFound == m_ColliderBuffer.Length && m_ColliderBuffer.Length < m_MaxColliderBufferSize)
{
m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_ColliderBuffer.Length * 2)];
m_InitialColliderBufferSize = m_ColliderBuffer.Length;
}
else
{
break;
}
}
return numFound;
}
/// <summary>
/// Parses the array of colliders found within a cell. Finds the closest gameobject to the agent root reference within the cell
/// </summary>
void ParseCollidersClosest(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter, Action<GameObject, int> detectedAction)
{
GameObject closestColliderGo = null;
var minDistanceSquared = float.MaxValue;
for (var i = 0; i < numFound; i++)
{
var currentColliderGo = foundColliders[i].gameObject;
// Continue if the current collider go is the root reference
if (ReferenceEquals(currentColliderGo, m_RootReference))
{
continue;
}
var closestColliderPoint = foundColliders[i].ClosestPointOnBounds(cellCenter);
var currentDistanceSquared = (closestColliderPoint - m_RootReference.transform.position).sqrMagnitude;
if (currentDistanceSquared >= minDistanceSquared)
{
continue;
}
// Checks if our colliders contain a detectable object
var index = -1;
for (var ii = 0; ii < m_DetectableTags.Length; ii++)
{
if (currentColliderGo.CompareTag(m_DetectableTags[ii]))
{
index = ii;
break;
}
}
if (index > -1 && currentDistanceSquared < minDistanceSquared)
{
minDistanceSquared = currentDistanceSquared;
closestColliderGo = currentColliderGo;
}
}
if (!ReferenceEquals(closestColliderGo, null))
{
detectedAction.Invoke(closestColliderGo, cellIndex);
}
}
/// <summary>
/// Parses all colliders in the array of colliders found within a cell.
/// </summary>
void ParseCollidersAll(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter, Action<GameObject, int> detectedAction)
{
for (int i = 0; i < numFound; i++)
{
var currentColliderGo = foundColliders[i].gameObject;
if (!ReferenceEquals(currentColliderGo, m_RootReference))
{
detectedAction.Invoke(currentColliderGo, cellIndex);
}
}
}
#endif
internal void RegisterSensor(GridSensorBase sensor)
{
#if MLA_UNITY_PHYSICS_MODULE
if (sensor.GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders)
{
GridOverlapDetectedAll += sensor.ProcessDetectedObject;
}
else
{
GridOverlapDetectedClosest += sensor.ProcessDetectedObject;
}
#endif
}
internal void RegisterDebugSensor(GridSensorBase debugSensor)
{
#if MLA_UNITY_PHYSICS_MODULE
GridOverlapDetectedDebug += debugSensor.ProcessDetectedObject;
#endif
}
}
}

293
com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs


using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// A SensorComponent that creates a <see cref="GridSensor"/>.
/// </summary>
[AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)]
public class GridSensorComponent : SensorComponent
{
// dummy sensor only used for debug gizmo
GridSensorBase m_DebugSensor;
List<ISensor> m_Sensors;
internal BoxOverlapChecker m_BoxOverlapChecker;
[HideInInspector, SerializeField]
protected internal string m_SensorName = "GridSensor";
/// <summary>
/// Name of the generated <see cref="GridSensor"/> object.
/// Note that changing this at runtime does not affect how the Agent sorts the sensors.
/// </summary>
public string SensorName
{
get { return m_SensorName; }
set { m_SensorName = value; }
}
[HideInInspector, SerializeField]
internal Vector3 m_CellScale = new Vector3(1f, 0.01f, 1f);
/// <summary>
/// The scale of each grid cell.
/// Note that changing this after the sensor is created has no effect.
/// </summary>
public Vector3 CellScale
{
get { return m_CellScale; }
set { m_CellScale = value; }
}
[HideInInspector, SerializeField]
internal Vector3Int m_GridSize = new Vector3Int(16, 1, 16);
/// <summary>
/// The number of grid on each side.
/// Note that changing this after the sensor is created has no effect.
/// </summary>
public Vector3Int GridSize
{
get { return m_GridSize; }
set
{
if (value.y != 1)
{
m_GridSize = new Vector3Int(value.x, 1, value.z);
}
else
{
m_GridSize = value;
}
}
}
[HideInInspector, SerializeField]
internal bool m_RotateWithAgent = true;
/// <summary>
/// Rotate the grid based on the direction the agent is facing.
/// </summary>
public bool RotateWithAgent
{
get { return m_RotateWithAgent; }
set { m_RotateWithAgent = value; }
}
[HideInInspector, SerializeField]
internal string[] m_DetectableTags;
/// <summary>
/// List of tags that are detected.
/// Note that changing this after the sensor is created has no effect.
/// </summary>
public string[] DetectableTags
{
get { return m_DetectableTags; }
set { m_DetectableTags = value; }
}
[HideInInspector, SerializeField]
internal LayerMask m_ColliderMask;
/// <summary>
/// The layer mask.
/// </summary>
public LayerMask ColliderMask
{
get { return m_ColliderMask; }
set { m_ColliderMask = value; }
}
[HideInInspector, SerializeField]
internal int m_MaxColliderBufferSize = 500;
/// <summary>
/// The absolute max size of the Collider buffer used in the non-allocating Physics calls. In other words
/// the Collider buffer will never grow beyond this number even if there are more Colliders in the Grid Cell.
/// Note that changing this after the sensor is created has no effect.
/// </summary>
public int MaxColliderBufferSize
{
get { return m_MaxColliderBufferSize; }
set { m_MaxColliderBufferSize = value; }
}
[HideInInspector, SerializeField]
internal int m_InitialColliderBufferSize = 4;
/// <summary>
/// The Estimated Max Number of Colliders to expect per cell. This number is used to
/// pre-allocate an array of Colliders in order to take advantage of the OverlapBoxNonAlloc
/// Physics API. If the number of colliders found is >= InitialColliderBufferSize the array
/// will be resized to double its current size. The hard coded absolute size is 500.
/// Note that changing this after the sensor is created has no effect.
/// </summary>
public int InitialColliderBufferSize
{
get { return m_InitialColliderBufferSize; }
set { m_InitialColliderBufferSize = value; }
}
[HideInInspector, SerializeField]
internal Color[] m_DebugColors;
/// <summary>
/// Array of Colors used for the grid gizmos.
/// </summary>
public Color[] DebugColors
{
get { return m_DebugColors; }
set { m_DebugColors = value; }
}
[HideInInspector, SerializeField]
internal float m_GizmoYOffset = 0f;
/// <summary>
/// The height of the gizmos grid.
/// </summary>
public float GizmoYOffset
{
get { return m_GizmoYOffset; }
set { m_GizmoYOffset = value; }
}
[HideInInspector, SerializeField]
internal bool m_ShowGizmos = false;
/// <summary>
/// Whether to show gizmos or not.
/// </summary>
public bool ShowGizmos
{
get { return m_ShowGizmos; }
set { m_ShowGizmos = value; }
}
[HideInInspector, SerializeField]
internal SensorCompressionType m_CompressionType = SensorCompressionType.PNG;
/// <summary>
/// The compression type to use for the sensor.
/// </summary>
public SensorCompressionType CompressionType
{
get { return m_CompressionType; }
set { m_CompressionType = value; UpdateSensor(); }
}
[HideInInspector, SerializeField]
[Range(1, 50)]
[Tooltip("Number of frames of observations that will be stacked before being fed to the neural network.")]
internal int m_ObservationStacks = 1;
/// <summary>
/// Whether to stack previous observations. Using 1 means no previous observations.
/// Note that changing this after the sensor is created has no effect.
/// </summary>
public int ObservationStacks
{
get { return m_ObservationStacks; }
set { m_ObservationStacks = value; }
}
/// <inheritdoc/>
public override ISensor[] CreateSensors()
{
List<ISensor> m_Sensors = new List<ISensor>();
m_BoxOverlapChecker = new BoxOverlapChecker(
m_CellScale,
m_GridSize,
m_RotateWithAgent,
m_ColliderMask,
gameObject,
m_DetectableTags,
m_InitialColliderBufferSize,
m_MaxColliderBufferSize
);
// debug data is positive int value and will trigger data validation exception if SensorCompressionType is not None.
m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor);
var gridSensors = GetGridSensors();
if (gridSensors == null || gridSensors.Length < 1)
{
throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors." +
"If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor.");
}
foreach (var sensor in gridSensors)
{
if (ObservationStacks != 1)
{
m_Sensors.Add(new StackingSensor(sensor, ObservationStacks));
}
else
{
m_Sensors.Add(sensor);
}
m_BoxOverlapChecker.RegisterSensor(sensor);
}
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
((GridSensorBase)m_Sensors[0]).m_BoxOverlapChecker = m_BoxOverlapChecker;
return m_Sensors.ToArray();
}
/// <summary>
/// Get an array of GridSensors to be added in this component.
/// Override this method and return custom GridSensor implementations.
/// </summary>
/// <returns>Array of grid sensors to be added to the component.</returns>
protected virtual GridSensorBase[] GetGridSensors()
{
List<GridSensorBase> sensorList = new List<GridSensorBase>();
var sensor = new OneHotGridSensor(m_SensorName + "-OneHot", m_CellScale, m_GridSize, m_DetectableTags, m_CompressionType);
sensorList.Add(sensor);
return sensorList.ToArray();
}
/// <summary>
/// Update fields that are safe to change on the Sensor at runtime.
/// </summary>
internal void UpdateSensor()
{
if (m_Sensors != null)
{
m_BoxOverlapChecker.RotateWithAgent = m_RotateWithAgent;
m_BoxOverlapChecker.ColliderMask = m_ColliderMask;
foreach (var sensor in m_Sensors)
{
((GridSensorBase)sensor).CompressionType = m_CompressionType;
}
}
}
void OnDrawGizmos()
{
if (m_ShowGizmos)
{
if (m_BoxOverlapChecker == null || m_DebugSensor == null)
{
return;
}
m_DebugSensor.ResetPerceptionBuffer();
m_BoxOverlapChecker.UpdateGizmo();
var cellColors = m_DebugSensor.PerceptionBuffer;
var rotation = m_BoxOverlapChecker.GetGridRotation();
var scale = new Vector3(m_CellScale.x, 1, m_CellScale.z);
var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0);
var oldGizmoMatrix = Gizmos.matrix;
for (var i = 0; i < m_DebugSensor.PerceptionBuffer.Length; i++)
{
var cellPosition = m_BoxOverlapChecker.GetCellGlobalPosition(i);
var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, scale);
Gizmos.matrix = oldGizmoMatrix * cubeTransform;
var colorIndex = cellColors[i] - 1;
var debugRayColor = Color.white;
if (colorIndex > -1 && m_DebugColors.Length > colorIndex)
{
debugRayColor = m_DebugColors[(int)colorIndex];
}
Gizmos.color = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
Gizmos.DrawCube(Vector3.zero, Vector3.one);
}
Gizmos.matrix = oldGizmoMatrix;
}
}
}
}

1001
com.unity.ml-agents.extensions/Documentation~/images/gridobs-vs-vectorobs.gif
文件差异内容过多而无法显示
查看文件

20
com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example-camera.png

之前 之后

94
com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example-gridsensor.png

之前 之后

67
com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example-raycast.png

之前 之后

79
com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example.p