浏览代码
Move GridSensor into main package (#5256)
Move GridSensor into main package (#5256)
* move OneHotGridSensor into main package * changelog and migration guide * remove old doc * check if physics module presents/check-for-ModelOverriders
GitHub
4 年前
当前提交
2a9c8f0d
共有 47 个文件被更改,包括 928 次插入 和 3439 次删除
-
60Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockAgentGridCollab.prefab
-
141Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockCollabAreaGrid.prefab
-
1com.unity.ml-agents.extensions/Documentation~/com.unity.ml-agents.extensions.md
-
6com.unity.ml-agents.extensions/Runtime/Sensors/CountingGridSensor.cs
-
5com.unity.ml-agents/CHANGELOG.md
-
2com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
-
4com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
-
28docs/Migrating.md
-
2com.unity.ml-agents/Editor/GridSensorComponentEditor.cs.meta
-
28com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs
-
9com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs
-
8com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs
-
2com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs
-
37com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs
-
24com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs
-
143com.unity.ml-agents.extensions/Tests/Runtime/Sensors/CountingGridSensorTests.cs
-
11com.unity.ml-agents.extensions/Tests/Runtime/Sensors/CountingGridSensorTests.cs.meta
-
108com.unity.ml-agents/Editor/GridSensorComponentEditor.cs
-
267com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs
-
293com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs
-
1001com.unity.ml-agents.extensions/Documentation~/images/gridobs-vs-vectorobs.gif
-
20com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example-camera.png
-
94com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example-gridsensor.png
-
67com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example-raycast.png
-
79com.unity.ml-agents.extensions/Documentation~/images/gridsensor-example.png
-
1001com.unity.ml-agents.extensions/Documentation~/images/gridsensor-debug.png
-
230com.unity.ml-agents.extensions/Documentation~/Grid-Sensor.md
-
106com.unity.ml-agents.extensions/Editor/GridSensorComponentEditor.cs
-
254com.unity.ml-agents.extensions/Runtime/Sensors/BoxOverlapChecker.cs
-
328com.unity.ml-agents.extensions/Runtime/Sensors/GridSensorComponent.cs
-
8com.unity.ml-agents.extensions/Tests/Editor/GridSensors.meta
-
0/com.unity.ml-agents/Editor/GridSensorComponentEditor.cs.meta
-
0/com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs.meta
-
0/com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs
-
0/com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs
-
0/com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs.meta
-
0/com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs.meta
-
0/com.unity.ml-agents/Runtime/Sensors/OneHotGridSensor.cs.meta
-
0/com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs
-
0/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs
-
0/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs
-
0/com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs
-
0/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTestUtils.cs.meta
-
0/com.unity.ml-agents/Tests/Runtime/Sensor/SimpleTestGridSensor.cs.meta
-
0/com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs.meta
-
0/com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs.meta
|
|||
using System; |
|||
using System.Collections; |
|||
using System.Linq; |
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using UnityEngine.TestTools; |
|||
using Unity.MLAgents.Sensors; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
using Object = UnityEngine.Object; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
public class CountingGridSensorTests |
|||
{ |
|||
GameObject testGo; |
|||
GameObject boxGo; |
|||
TestCountingGridSensorComponent gridSensorComponent; |
|||
|
|||
// Use built-in tags
|
|||
const string k_Tag1 = "Player"; |
|||
const string k_Tag2 = "Respawn"; |
|||
|
|||
[UnitySetUp] |
|||
public IEnumerator SetupScene() |
|||
{ |
|||
testGo = new GameObject("test"); |
|||
testGo.transform.position = Vector3.zero; |
|||
gridSensorComponent = testGo.AddComponent<TestCountingGridSensorComponent>(); |
|||
|
|||
boxGo = new GameObject("block"); |
|||
boxGo.tag = k_Tag1; |
|||
boxGo.transform.position = new Vector3(3f, 0f, 3f); |
|||
boxGo.AddComponent<BoxCollider>(); |
|||
|
|||
yield return null; |
|||
} |
|||
|
|||
[TearDown] |
|||
public void ClearScene() |
|||
{ |
|||
Object.DestroyImmediate(boxGo); |
|||
Object.DestroyImmediate(testGo); |
|||
} |
|||
|
|||
public class TestCountingGridSensorComponent : GridSensorComponent |
|||
{ |
|||
public void SetParameters(string[] detectableTags) |
|||
{ |
|||
DetectableTags = detectableTags; |
|||
CellScale = new Vector3(1, 0.01f, 1); |
|||
GridSize = new Vector3Int(10, 1, 10); |
|||
ColliderMask = LayerMask.GetMask("Default"); |
|||
RotateWithAgent = false; |
|||
CompressionType = SensorCompressionType.None; |
|||
} |
|||
|
|||
protected override GridSensorBase[] GetGridSensors() |
|||
{ |
|||
return new GridSensorBase[] { |
|||
new CountingGridSensor( |
|||
"TestSensor", |
|||
CellScale, |
|||
GridSize, |
|||
DetectableTags, |
|||
CompressionType) }; |
|||
} |
|||
} |
|||
|
|||
// Copied from GridSensorTests in main package
|
|||
public static float[][] DuplicateArray(float[] array, int numCopies) |
|||
{ |
|||
float[][] duplicated = new float[numCopies][]; |
|||
for (int i = 0; i < numCopies; i++) |
|||
{ |
|||
duplicated[i] = array; |
|||
} |
|||
return duplicated; |
|||
} |
|||
|
|||
// Copied from GridSensorTests in main package
|
|||
public static void AssertSubarraysAtIndex(float[] total, int[] indicies, float[][] expectedArrays, float[] expectedDefaultArray) |
|||
{ |
|||
int totalIndex = 0; |
|||
int subIndex = 0; |
|||
int subarrayIndex = 0; |
|||
int lenOfData = expectedDefaultArray.Length; |
|||
int numArrays = total.Length / lenOfData; |
|||
for (int i = 0; i < numArrays; i++) |
|||
{ |
|||
totalIndex = i * lenOfData; |
|||
|
|||
if (indicies.Contains(i)) |
|||
{ |
|||
subarrayIndex = Array.IndexOf(indicies, i); |
|||
for (subIndex = 0; subIndex < lenOfData; subIndex++) |
|||
{ |
|||
Assert.AreEqual(expectedArrays[subarrayIndex][subIndex], total[totalIndex], |
|||
"Expected " + expectedArrays[subarrayIndex][subIndex] + " at subarray index " + totalIndex + ", index = " + subIndex + " but was " + total[totalIndex]); |
|||
totalIndex++; |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
for (subIndex = 0; subIndex < lenOfData; subIndex++) |
|||
{ |
|||
Assert.AreEqual(expectedDefaultArray[subIndex], total[totalIndex], |
|||
"Expected default value " + expectedDefaultArray[subIndex] + " at subarray index " + totalIndex + ", index = " + subIndex + " but was " + total[totalIndex]); |
|||
totalIndex++; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestCountingSensor() |
|||
{ |
|||
string[] tags = { k_Tag1, k_Tag2 }; |
|||
gridSensorComponent.SetParameters(tags); |
|||
var gridSensor = (CountingGridSensor)gridSensorComponent.CreateSensors()[0]; |
|||
Assert.AreEqual(gridSensor.PerceptionBuffer.Length, 10 * 10 * 2); |
|||
|
|||
gridSensor.Update(); |
|||
|
|||
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 }; |
|||
float[][] expectedSubarrays = DuplicateArray(new float[] { 1, 0 }, 4); |
|||
float[] expectedDefault = new float[] { 0, 0 }; |
|||
AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndicies, expectedSubarrays, expectedDefault); |
|||
|
|||
var boxGo2 = new GameObject("block"); |
|||
boxGo2.tag = k_Tag1; |
|||
boxGo2.transform.position = new Vector3(3.1f, 0f, 3f); |
|||
boxGo2.AddComponent<BoxCollider>(); |
|||
|
|||
gridSensor.Update(); |
|||
|
|||
subarrayIndicies = new int[] { 77, 78, 87, 88 }; |
|||
expectedSubarrays = DuplicateArray(new float[] { 2, 0 }, 4); |
|||
expectedDefault = new float[] { 0, 0 }; |
|||
AssertSubarraysAtIndex(gridSensor.PerceptionBuffer, subarrayIndicies, expectedSubarrays, expectedDefault); |
|||
Object.DestroyImmediate(boxGo2); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 2a1d17f91519347e0a8692e2816b7c8b |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using UnityEditor; |
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Editor |
|||
{ |
|||
[CustomEditor(typeof(GridSensorComponent))] |
|||
[CanEditMultipleObjects] |
|||
internal class GridSensorComponentEditor : UnityEditor.Editor |
|||
{ |
|||
public override void OnInspectorGUI() |
|||
{ |
|||
#if !MLA_UNITY_PHYSICS_MODULE
|
|||
EditorGUILayout.HelpBox("The Physics Module is not currently present. " + |
|||
"Please add it to your project in order to use the GridSensor APIs in the " + |
|||
$"{nameof(GridSensorComponent)}", MessageType.Warning); |
|||
#endif
|
|||
|
|||
var so = serializedObject; |
|||
so.Update(); |
|||
|
|||
// Drawing the GridSensorComponent
|
|||
EditorGUI.BeginChangeCheck(); |
|||
|
|||
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); |
|||
{ |
|||
// These fields affect the sensor order or observation size,
|
|||
// So can't be changed at runtime.
|
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_SensorName)), true); |
|||
|
|||
EditorGUILayout.LabelField("Grid Settings", EditorStyles.boldLabel); |
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_CellScale)), true); |
|||
// We only supports 2D GridSensor now so lock gridSize.y to 1
|
|||
var gridSize = so.FindProperty(nameof(GridSensorComponent.m_GridSize)); |
|||
var gridSize2d = new Vector3Int(gridSize.vector3IntValue.x, 1, gridSize.vector3IntValue.z); |
|||
var newGridSize = EditorGUILayout.Vector3IntField("Grid Size", gridSize2d); |
|||
gridSize.vector3IntValue = new Vector3Int(newGridSize.x, 1, newGridSize.z); |
|||
} |
|||
EditorGUI.EndDisabledGroup(); |
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_RotateWithAgent)), true); |
|||
|
|||
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); |
|||
{ |
|||
// detectable tags
|
|||
var detectableTags = so.FindProperty(nameof(GridSensorComponent.m_DetectableTags)); |
|||
var newSize = EditorGUILayout.IntField("Detectable Tags", detectableTags.arraySize); |
|||
if (newSize != detectableTags.arraySize) |
|||
{ |
|||
detectableTags.arraySize = newSize; |
|||
} |
|||
EditorGUI.indentLevel++; |
|||
for (var i = 0; i < detectableTags.arraySize; i++) |
|||
{ |
|||
var objectTag = detectableTags.GetArrayElementAtIndex(i); |
|||
EditorGUILayout.PropertyField(objectTag, new GUIContent("Tag " + i), true); |
|||
} |
|||
EditorGUI.indentLevel--; |
|||
} |
|||
EditorGUI.EndDisabledGroup(); |
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ColliderMask)), true); |
|||
EditorGUILayout.LabelField("Sensor Settings", EditorStyles.boldLabel); |
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ObservationStacks)), true); |
|||
EditorGUI.EndDisabledGroup(); |
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_CompressionType)), true); |
|||
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties()); |
|||
{ |
|||
EditorGUILayout.LabelField("Collider and Buffer", EditorStyles.boldLabel); |
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_InitialColliderBufferSize)), true); |
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_MaxColliderBufferSize)), true); |
|||
} |
|||
EditorGUI.EndDisabledGroup(); |
|||
|
|||
EditorGUILayout.LabelField("Debug Gizmo", EditorStyles.boldLabel); |
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ShowGizmos)), true); |
|||
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_GizmoYOffset)), true); |
|||
|
|||
// detectable objects
|
|||
var debugColors = so.FindProperty(nameof(GridSensorComponent.m_DebugColors)); |
|||
var detectableObjectSize = so.FindProperty(nameof(GridSensorComponent.m_DetectableTags)).arraySize; |
|||
if (detectableObjectSize != debugColors.arraySize) |
|||
{ |
|||
debugColors.arraySize = detectableObjectSize; |
|||
} |
|||
EditorGUILayout.LabelField("Debug Colors"); |
|||
EditorGUI.indentLevel++; |
|||
for (var i = 0; i < debugColors.arraySize; i++) |
|||
{ |
|||
var debugColor = debugColors.GetArrayElementAtIndex(i); |
|||
EditorGUILayout.PropertyField(debugColor, new GUIContent("Tag " + i + " Color"), true); |
|||
} |
|||
EditorGUI.indentLevel--; |
|||
|
|||
var requireSensorUpdate = EditorGUI.EndChangeCheck(); |
|||
so.ApplyModifiedProperties(); |
|||
|
|||
if (requireSensorUpdate) |
|||
{ |
|||
UpdateSensor(); |
|||
} |
|||
} |
|||
|
|||
void UpdateSensor() |
|||
{ |
|||
var sensorComponent = serializedObject.targetObject as GridSensorComponent; |
|||
sensorComponent?.UpdateSensor(); |
|||
} |
|||
} |
|||
} |
|
|||
using System; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
internal class BoxOverlapChecker |
|||
{ |
|||
Vector3 m_CellScale; |
|||
Vector3Int m_GridSize; |
|||
bool m_RotateWithAgent; |
|||
LayerMask m_ColliderMask; |
|||
GameObject m_RootReference; |
|||
string[] m_DetectableTags; |
|||
int m_InitialColliderBufferSize; |
|||
int m_MaxColliderBufferSize; |
|||
|
|||
int m_NumCells; |
|||
Vector3 m_HalfCellScale; |
|||
Vector3 m_CellCenterOffset; |
|||
Vector3[] m_CellLocalPositions; |
|||
|
|||
#if MLA_UNITY_PHYSICS_MODULE
|
|||
Collider[] m_ColliderBuffer; |
|||
|
|||
public event Action<GameObject, int> GridOverlapDetectedAll; |
|||
public event Action<GameObject, int> GridOverlapDetectedClosest; |
|||
public event Action<GameObject, int> GridOverlapDetectedDebug; |
|||
#endif
|
|||
|
|||
public BoxOverlapChecker( |
|||
Vector3 cellScale, |
|||
Vector3Int gridSize, |
|||
bool rotateWithAgent, |
|||
LayerMask colliderMask, |
|||
GameObject rootReference, |
|||
string[] detectableTags, |
|||
int initialColliderBufferSize, |
|||
int maxColliderBufferSize) |
|||
{ |
|||
m_CellScale = cellScale; |
|||
m_GridSize = gridSize; |
|||
m_RotateWithAgent = rotateWithAgent; |
|||
m_ColliderMask = colliderMask; |
|||
m_RootReference = rootReference; |
|||
m_DetectableTags = detectableTags; |
|||
m_InitialColliderBufferSize = initialColliderBufferSize; |
|||
m_MaxColliderBufferSize = maxColliderBufferSize; |
|||
|
|||
m_NumCells = gridSize.x * gridSize.z; |
|||
m_HalfCellScale = new Vector3(cellScale.x / 2f, cellScale.y, cellScale.z / 2f); |
|||
m_CellCenterOffset = new Vector3((gridSize.x - 1f) / 2, 0, (gridSize.z - 1f) / 2); |
|||
#if MLA_UNITY_PHYSICS_MODULE
|
|||
m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_InitialColliderBufferSize)]; |
|||
#endif
|
|||
|
|||
InitCellLocalPositions(); |
|||
} |
|||
|
|||
public bool RotateWithAgent |
|||
{ |
|||
get { return m_RotateWithAgent; } |
|||
set { m_RotateWithAgent = value; } |
|||
} |
|||
|
|||
public LayerMask ColliderMask |
|||
{ |
|||
get { return m_ColliderMask; } |
|||
set { m_ColliderMask = value; } |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Initializes the local location of the cells
|
|||
/// </summary>
|
|||
void InitCellLocalPositions() |
|||
{ |
|||
m_CellLocalPositions = new Vector3[m_NumCells]; |
|||
|
|||
for (int i = 0; i < m_NumCells; i++) |
|||
{ |
|||
m_CellLocalPositions[i] = GetCellLocalPosition(i); |
|||
} |
|||
} |
|||
|
|||
/// <summary>Converts the index of the cell to the 3D point (y is zero) relative to grid center</summary>
|
|||
/// <returns>Vector3 of the position of the center of the cell relative to grid center</returns>
|
|||
/// <param name="cell">The index of the cell</param>
|
|||
Vector3 GetCellLocalPosition(int cellIndex) |
|||
{ |
|||
float x = (cellIndex / m_GridSize.z - m_CellCenterOffset.x) * m_CellScale.x; |
|||
float z = (cellIndex % m_GridSize.z - m_CellCenterOffset.z) * m_CellScale.z; |
|||
return new Vector3(x, 0, z); |
|||
} |
|||
|
|||
internal Vector3 GetCellGlobalPosition(int cellIndex) |
|||
{ |
|||
if (m_RotateWithAgent) |
|||
{ |
|||
return m_RootReference.transform.TransformPoint(m_CellLocalPositions[cellIndex]); |
|||
} |
|||
else |
|||
{ |
|||
return m_CellLocalPositions[cellIndex] + m_RootReference.transform.position; |
|||
} |
|||
} |
|||
|
|||
internal Quaternion GetGridRotation() |
|||
{ |
|||
return m_RotateWithAgent ? m_RootReference.transform.rotation : Quaternion.identity; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Perceive the latest grid status. Call OverlapBoxNonAlloc once to detect colliders.
|
|||
/// Then parse the collider arrays according to all available gridSensor delegates.
|
|||
/// </summary>
|
|||
internal void Update() |
|||
{ |
|||
#if MLA_UNITY_PHYSICS_MODULE
|
|||
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++) |
|||
{ |
|||
var cellCenter = GetCellGlobalPosition(cellIndex); |
|||
var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, m_HalfCellScale, GetGridRotation()); |
|||
|
|||
if (GridOverlapDetectedAll != null) |
|||
{ |
|||
ParseCollidersAll(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedAll); |
|||
} |
|||
if (GridOverlapDetectedClosest != null) |
|||
{ |
|||
ParseCollidersClosest(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedClosest); |
|||
} |
|||
} |
|||
#endif
|
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Same as Update(), but only load data for debug gizmo.
|
|||
/// </summary>
|
|||
internal void UpdateGizmo() |
|||
{ |
|||
#if MLA_UNITY_PHYSICS_MODULE
|
|||
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++) |
|||
{ |
|||
var cellCenter = GetCellGlobalPosition(cellIndex); |
|||
var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, m_HalfCellScale, GetGridRotation()); |
|||
|
|||
ParseCollidersClosest(m_ColliderBuffer, numFound, cellIndex, cellCenter, GridOverlapDetectedDebug); |
|||
} |
|||
#endif
|
|||
} |
|||
|
|||
#if MLA_UNITY_PHYSICS_MODULE
|
|||
/// <summary>
|
|||
/// This method attempts to perform the Physics.OverlapBoxNonAlloc and will double the size of the Collider buffer
|
|||
/// if the number of Colliders in the buffer after the call is equal to the length of the buffer.
|
|||
/// </summary>
|
|||
/// <param name="cellCenter"></param>
|
|||
/// <param name="halfCellScale"></param>
|
|||
/// <param name="rotation"></param>
|
|||
/// <returns></returns>
|
|||
int BufferResizingOverlapBoxNonAlloc(Vector3 cellCenter, Vector3 halfCellScale, Quaternion rotation) |
|||
{ |
|||
int numFound; |
|||
// Since we can only get a fixed number of results, requery
|
|||
// until we're sure we can hold them all (or until we hit the max size).
|
|||
while (true) |
|||
{ |
|||
numFound = Physics.OverlapBoxNonAlloc(cellCenter, halfCellScale, m_ColliderBuffer, rotation, m_ColliderMask); |
|||
if (numFound == m_ColliderBuffer.Length && m_ColliderBuffer.Length < m_MaxColliderBufferSize) |
|||
{ |
|||
m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_ColliderBuffer.Length * 2)]; |
|||
m_InitialColliderBufferSize = m_ColliderBuffer.Length; |
|||
} |
|||
else |
|||
{ |
|||
break; |
|||
} |
|||
} |
|||
return numFound; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Parses the array of colliders found within a cell. Finds the closest gameobject to the agent root reference within the cell
|
|||
/// </summary>
|
|||
void ParseCollidersClosest(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter, Action<GameObject, int> detectedAction) |
|||
{ |
|||
GameObject closestColliderGo = null; |
|||
var minDistanceSquared = float.MaxValue; |
|||
|
|||
for (var i = 0; i < numFound; i++) |
|||
{ |
|||
var currentColliderGo = foundColliders[i].gameObject; |
|||
|
|||
// Continue if the current collider go is the root reference
|
|||
if (ReferenceEquals(currentColliderGo, m_RootReference)) |
|||
{ |
|||
continue; |
|||
} |
|||
|
|||
var closestColliderPoint = foundColliders[i].ClosestPointOnBounds(cellCenter); |
|||
var currentDistanceSquared = (closestColliderPoint - m_RootReference.transform.position).sqrMagnitude; |
|||
|
|||
if (currentDistanceSquared >= minDistanceSquared) |
|||
{ |
|||
continue; |
|||
} |
|||
|
|||
// Checks if our colliders contain a detectable object
|
|||
var index = -1; |
|||
for (var ii = 0; ii < m_DetectableTags.Length; ii++) |
|||
{ |
|||
if (currentColliderGo.CompareTag(m_DetectableTags[ii])) |
|||
{ |
|||
index = ii; |
|||
break; |
|||
} |
|||
} |
|||
if (index > -1 && currentDistanceSquared < minDistanceSquared) |
|||
{ |
|||
minDistanceSquared = currentDistanceSquared; |
|||
closestColliderGo = currentColliderGo; |
|||
} |
|||
} |
|||
|
|||
if (!ReferenceEquals(closestColliderGo, null)) |
|||
{ |
|||
detectedAction.Invoke(closestColliderGo, cellIndex); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Parses all colliders in the array of colliders found within a cell.
|
|||
/// </summary>
|
|||
void ParseCollidersAll(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter, Action<GameObject, int> detectedAction) |
|||
{ |
|||
for (int i = 0; i < numFound; i++) |
|||
{ |
|||
var currentColliderGo = foundColliders[i].gameObject; |
|||
if (!ReferenceEquals(currentColliderGo, m_RootReference)) |
|||
{ |
|||
detectedAction.Invoke(currentColliderGo, cellIndex); |
|||
} |
|||
} |
|||
} |
|||
#endif
|
|||
|
|||
internal void RegisterSensor(GridSensorBase sensor) |
|||
{ |
|||
#if MLA_UNITY_PHYSICS_MODULE
|
|||
if (sensor.GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders) |
|||
{ |
|||
GridOverlapDetectedAll += sensor.ProcessDetectedObject; |
|||
} |
|||
else |
|||
{ |
|||
GridOverlapDetectedClosest += sensor.ProcessDetectedObject; |
|||
} |
|||
#endif
|
|||
} |
|||
|
|||
internal void RegisterDebugSensor(GridSensorBase debugSensor) |
|||
{ |
|||
#if MLA_UNITY_PHYSICS_MODULE
|
|||
GridOverlapDetectedDebug += debugSensor.ProcessDetectedObject; |
|||
#endif
|
|||
} |
|||
} |
|||
} |
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// A SensorComponent that creates a <see cref="GridSensor"/>.
|
|||
/// </summary>
|
|||
[AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)] |
|||
public class GridSensorComponent : SensorComponent |
|||
{ |
|||
// dummy sensor only used for debug gizmo
|
|||
GridSensorBase m_DebugSensor; |
|||
List<ISensor> m_Sensors; |
|||
internal BoxOverlapChecker m_BoxOverlapChecker; |
|||
|
|||
[HideInInspector, SerializeField] |
|||
protected internal string m_SensorName = "GridSensor"; |
|||
/// <summary>
|
|||
/// Name of the generated <see cref="GridSensor"/> object.
|
|||
/// Note that changing this at runtime does not affect how the Agent sorts the sensors.
|
|||
/// </summary>
|
|||
public string SensorName |
|||
{ |
|||
get { return m_SensorName; } |
|||
set { m_SensorName = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal Vector3 m_CellScale = new Vector3(1f, 0.01f, 1f); |
|||
|
|||
/// <summary>
|
|||
/// The scale of each grid cell.
|
|||
/// Note that changing this after the sensor is created has no effect.
|
|||
/// </summary>
|
|||
public Vector3 CellScale |
|||
{ |
|||
get { return m_CellScale; } |
|||
set { m_CellScale = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal Vector3Int m_GridSize = new Vector3Int(16, 1, 16); |
|||
/// <summary>
|
|||
/// The number of grid on each side.
|
|||
/// Note that changing this after the sensor is created has no effect.
|
|||
/// </summary>
|
|||
public Vector3Int GridSize |
|||
{ |
|||
get { return m_GridSize; } |
|||
set |
|||
{ |
|||
if (value.y != 1) |
|||
{ |
|||
m_GridSize = new Vector3Int(value.x, 1, value.z); |
|||
} |
|||
else |
|||
{ |
|||
m_GridSize = value; |
|||
} |
|||
} |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal bool m_RotateWithAgent = true; |
|||
/// <summary>
|
|||
/// Rotate the grid based on the direction the agent is facing.
|
|||
/// </summary>
|
|||
public bool RotateWithAgent |
|||
{ |
|||
get { return m_RotateWithAgent; } |
|||
set { m_RotateWithAgent = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal string[] m_DetectableTags; |
|||
/// <summary>
|
|||
/// List of tags that are detected.
|
|||
/// Note that changing this after the sensor is created has no effect.
|
|||
/// </summary>
|
|||
public string[] DetectableTags |
|||
{ |
|||
get { return m_DetectableTags; } |
|||
set { m_DetectableTags = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal LayerMask m_ColliderMask; |
|||
/// <summary>
|
|||
/// The layer mask.
|
|||
/// </summary>
|
|||
public LayerMask ColliderMask |
|||
{ |
|||
get { return m_ColliderMask; } |
|||
set { m_ColliderMask = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal int m_MaxColliderBufferSize = 500; |
|||
/// <summary>
|
|||
/// The absolute max size of the Collider buffer used in the non-allocating Physics calls. In other words
|
|||
/// the Collider buffer will never grow beyond this number even if there are more Colliders in the Grid Cell.
|
|||
/// Note that changing this after the sensor is created has no effect.
|
|||
/// </summary>
|
|||
public int MaxColliderBufferSize |
|||
{ |
|||
get { return m_MaxColliderBufferSize; } |
|||
set { m_MaxColliderBufferSize = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal int m_InitialColliderBufferSize = 4; |
|||
/// <summary>
|
|||
/// The Estimated Max Number of Colliders to expect per cell. This number is used to
|
|||
/// pre-allocate an array of Colliders in order to take advantage of the OverlapBoxNonAlloc
|
|||
/// Physics API. If the number of colliders found is >= InitialColliderBufferSize the array
|
|||
/// will be resized to double its current size. The hard coded absolute size is 500.
|
|||
/// Note that changing this after the sensor is created has no effect.
|
|||
/// </summary>
|
|||
public int InitialColliderBufferSize |
|||
{ |
|||
get { return m_InitialColliderBufferSize; } |
|||
set { m_InitialColliderBufferSize = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal Color[] m_DebugColors; |
|||
/// <summary>
|
|||
/// Array of Colors used for the grid gizmos.
|
|||
/// </summary>
|
|||
public Color[] DebugColors |
|||
{ |
|||
get { return m_DebugColors; } |
|||
set { m_DebugColors = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal float m_GizmoYOffset = 0f; |
|||
/// <summary>
|
|||
/// The height of the gizmos grid.
|
|||
/// </summary>
|
|||
public float GizmoYOffset |
|||
{ |
|||
get { return m_GizmoYOffset; } |
|||
set { m_GizmoYOffset = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal bool m_ShowGizmos = false; |
|||
/// <summary>
|
|||
/// Whether to show gizmos or not.
|
|||
/// </summary>
|
|||
public bool ShowGizmos |
|||
{ |
|||
get { return m_ShowGizmos; } |
|||
set { m_ShowGizmos = value; } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
internal SensorCompressionType m_CompressionType = SensorCompressionType.PNG; |
|||
/// <summary>
|
|||
/// The compression type to use for the sensor.
|
|||
/// </summary>
|
|||
public SensorCompressionType CompressionType |
|||
{ |
|||
get { return m_CompressionType; } |
|||
set { m_CompressionType = value; UpdateSensor(); } |
|||
} |
|||
|
|||
[HideInInspector, SerializeField] |
|||
[Range(1, 50)] |
|||
[Tooltip("Number of frames of observations that will be stacked before being fed to the neural network.")] |
|||
internal int m_ObservationStacks = 1; |
|||
/// <summary>
|
|||
/// Whether to stack previous observations. Using 1 means no previous observations.
|
|||
/// Note that changing this after the sensor is created has no effect.
|
|||
/// </summary>
|
|||
public int ObservationStacks |
|||
{ |
|||
get { return m_ObservationStacks; } |
|||
set { m_ObservationStacks = value; } |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public override ISensor[] CreateSensors() |
|||
{ |
|||
List<ISensor> m_Sensors = new List<ISensor>(); |
|||
m_BoxOverlapChecker = new BoxOverlapChecker( |
|||
m_CellScale, |
|||
m_GridSize, |
|||
m_RotateWithAgent, |
|||
m_ColliderMask, |
|||
gameObject, |
|||
m_DetectableTags, |
|||
m_InitialColliderBufferSize, |
|||
m_MaxColliderBufferSize |
|||
); |
|||
|
|||
// debug data is positive int value and will trigger data validation exception if SensorCompressionType is not None.
|
|||
m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None); |
|||
m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor); |
|||
|
|||
var gridSensors = GetGridSensors(); |
|||
if (gridSensors == null || gridSensors.Length < 1) |
|||
{ |
|||
throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors." + |
|||
"If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor."); |
|||
} |
|||
|
|||
foreach (var sensor in gridSensors) |
|||
{ |
|||
if (ObservationStacks != 1) |
|||
{ |
|||
m_Sensors.Add(new StackingSensor(sensor, ObservationStacks)); |
|||
} |
|||
else |
|||
{ |
|||
m_Sensors.Add(sensor); |
|||
} |
|||
m_BoxOverlapChecker.RegisterSensor(sensor); |
|||
} |
|||
|
|||
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
|
|||
((GridSensorBase)m_Sensors[0]).m_BoxOverlapChecker = m_BoxOverlapChecker; |
|||
return m_Sensors.ToArray(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Get an array of GridSensors to be added in this component.
|
|||
/// Override this method and return custom GridSensor implementations.
|
|||
/// </summary>
|
|||
/// <returns>Array of grid sensors to be added to the component.</returns>
|
|||
protected virtual GridSensorBase[] GetGridSensors() |
|||
{ |
|||
List<GridSensorBase> sensorList = new List<GridSensorBase>(); |
|||
var sensor = new OneHotGridSensor(m_SensorName + "-OneHot", m_CellScale, m_GridSize, m_DetectableTags, m_CompressionType); |
|||
sensorList.Add(sensor); |
|||
return sensorList.ToArray(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Update fields that are safe to change on the Sensor at runtime.
|
|||
/// </summary>
|
|||
internal void UpdateSensor() |
|||
{ |
|||
if (m_Sensors != null) |
|||
{ |
|||
m_BoxOverlapChecker.RotateWithAgent = m_RotateWithAgent; |
|||
m_BoxOverlapChecker.ColliderMask = m_ColliderMask; |
|||
foreach (var sensor in m_Sensors) |
|||
{ |
|||
((GridSensorBase)sensor).CompressionType = m_CompressionType; |
|||
} |
|||
} |
|||
} |
|||
|
|||
void OnDrawGizmos() |
|||
{ |
|||
if (m_ShowGizmos) |
|||
{ |
|||
if (m_BoxOverlapChecker == null || m_DebugSensor == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
m_DebugSensor.ResetPerceptionBuffer(); |
|||
m_BoxOverlapChecker.UpdateGizmo(); |
|||
var cellColors = m_DebugSensor.PerceptionBuffer; |
|||
var rotation = m_BoxOverlapChecker.GetGridRotation(); |
|||
|
|||
var scale = new Vector3(m_CellScale.x, 1, m_CellScale.z); |
|||
var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0); |
|||
var oldGizmoMatrix = Gizmos.matrix; |
|||
for (var i = 0; i < m_DebugSensor.PerceptionBuffer.Length; i++) |
|||
{ |
|||
var cellPosition = m_BoxOverlapChecker.GetCellGlobalPosition(i); |
|||
var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, scale); |
|||
Gizmos.matrix = oldGizmoMatrix * cubeTransform; |
|||
var colorIndex = cellColors[i] - 1; |
|||
var debugRayColor = Color.white; |
|||
if (colorIndex > -1 && m_DebugColors.Length > colorIndex) |
|||
{ |
|||
debugRayColor = m_DebugColors[(int)colorIndex]; |
|||
} |
|||
Gizmos.color = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f); |
|||
Gizmos.DrawCube(Vector3.zero, Vector3.one); |
|||
} |
|||
|
|||
Gizmos.matrix = oldGizmoMatrix; |
|||
} |
|||
} |
|||
} |
|||
} |
1001
com.unity.ml-agents.extensions/Documentation~/images/gridobs-vs-vectorobs.gif
文件差异内容过多而无法显示
查看文件
文件差异内容过多而无法显示
查看文件