浏览代码

Make OverlapChecker an interface (#5324)

/colab-links
GitHub 4 年前
当前提交
fac11fa7
共有 7 个文件被更改,包括 107 次插入41 次删除
  1. 29
      com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs
  2. 6
      com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs
  3. 22
      com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs
  4. 12
      com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs
  5. 6
      com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs
  6. 62
      com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs
  7. 11
      com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs.meta

29
com.unity.ml-agents/Runtime/Sensors/BoxOverlapChecker.cs


namespace Unity.MLAgents.Sensors
{
internal class BoxOverlapChecker
/// <summary>
/// The grid perception strategy that uses box overlap to detect objects.
/// </summary>
internal class BoxOverlapChecker : IGridPerception
{
Vector3 m_CellScale;
Vector3Int m_GridSize;

}
}
/// <summary>Converts the index of the cell to the 3D point (y is zero) relative to grid center</summary>
/// <returns>Vector3 of the position of the center of the cell relative to grid center</returns>
/// <param name="cellIndex">The index of the cell</param>
Vector3 GetCellLocalPosition(int cellIndex)
public Vector3 GetCellLocalPosition(int cellIndex)
{
float x = (cellIndex / m_GridSize.z - m_CellCenterOffset.x) * m_CellScale.x;
float z = (cellIndex % m_GridSize.z - m_CellCenterOffset.z) * m_CellScale.z;

internal Vector3 GetCellGlobalPosition(int cellIndex)
public Vector3 GetCellGlobalPosition(int cellIndex)
{
if (m_RotateWithAgent)
{

}
}
internal Quaternion GetGridRotation()
public Quaternion GetGridRotation()
/// <summary>
/// Perceive the latest grid status. Call OverlapBoxNonAlloc once to detect colliders.
/// Then parse the collider arrays according to all available gridSensor delegates.
/// </summary>
internal void Update()
public void Perceive()
{
#if MLA_UNITY_PHYSICS_MODULE
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)

#endif
}
/// <summary>
/// Same as Update(), but only load data for debug gizmo.
/// </summary>
internal void UpdateGizmo()
public void UpdateGizmo()
{
#if MLA_UNITY_PHYSICS_MODULE
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)

}
#endif
internal void RegisterSensor(GridSensorBase sensor)
public void RegisterSensor(GridSensorBase sensor)
{
#if MLA_UNITY_PHYSICS_MODULE
if (sensor.GetProcessCollidersMethod() == ProcessCollidersMethod.ProcessAllColliders)

#endif
}
internal void RegisterDebugSensor(GridSensorBase debugSensor)
public void RegisterDebugSensor(GridSensorBase debugSensor)
{
#if MLA_UNITY_PHYSICS_MODULE
GridOverlapDetectedDebug += debugSensor.ProcessDetectedObject;

6
com.unity.ml-agents/Runtime/Sensors/GridSensorBase.cs


string[] m_DetectableTags;
SensorCompressionType m_CompressionType;
ObservationSpec m_ObservationSpec;
internal BoxOverlapChecker m_BoxOverlapChecker;
internal IGridPerception m_GridPerception;
// Buffers
float[] m_PerceptionBuffer;

ResetPerceptionBuffer();
using (TimerStack.Instance.Scoped("GridSensor.Update"))
{
if (m_BoxOverlapChecker != null)
if (m_GridPerception != null)
m_BoxOverlapChecker.Update();
m_GridPerception.Perceive();
}
}
}

22
com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs


// dummy sensor only used for debug gizmo
GridSensorBase m_DebugSensor;
List<GridSensorBase> m_Sensors;
internal BoxOverlapChecker m_BoxOverlapChecker;
internal IGridPerception m_GridPerception;
[HideInInspector, SerializeField]
protected internal string m_SensorName = "GridSensor";

/// <inheritdoc/>
public override ISensor[] CreateSensors()
{
m_BoxOverlapChecker = new BoxOverlapChecker(
m_GridPerception = new BoxOverlapChecker(
m_CellScale,
m_GridSize,
m_RotateWithAgent,

// debug data is positive int value and will trigger data validation exception if SensorCompressionType is not None.
m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor);
m_GridPerception.RegisterDebugSensor(m_DebugSensor);
m_Sensors = GetGridSensors().ToList();
if (m_Sensors == null || m_Sensors.Count < 1)

}
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
m_Sensors[0].m_BoxOverlapChecker = m_BoxOverlapChecker;
m_Sensors[0].m_GridPerception = m_GridPerception;
m_BoxOverlapChecker.RegisterSensor(sensor);
m_GridPerception.RegisterSensor(sensor);
}
if (ObservationStacks != 1)

{
if (m_Sensors != null)
{
m_BoxOverlapChecker.RotateWithAgent = m_RotateWithAgent;
m_BoxOverlapChecker.ColliderMask = m_ColliderMask;
m_GridPerception.RotateWithAgent = m_RotateWithAgent;
m_GridPerception.ColliderMask = m_ColliderMask;
foreach (var sensor in m_Sensors)
{
sensor.CompressionType = m_CompressionType;

{
if (m_ShowGizmos)
{
if (m_BoxOverlapChecker == null || m_DebugSensor == null)
if (m_GridPerception == null || m_DebugSensor == null)
m_BoxOverlapChecker.UpdateGizmo();
m_GridPerception.UpdateGizmo();
var rotation = m_BoxOverlapChecker.GetGridRotation();
var rotation = m_GridPerception.GetGridRotation();
var scale = new Vector3(m_CellScale.x, 1, m_CellScale.z);
var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0);

var cellPosition = m_BoxOverlapChecker.GetCellGlobalPosition(i);
var cellPosition = m_GridPerception.GetCellGlobalPosition(i);
var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, scale);
Gizmos.matrix = oldGizmoMatrix * cubeTransform;
var colorIndex = cellColors[i] - 1;

12
com.unity.ml-agents/Tests/Runtime/Sensor/BoxOverlapCheckerTests.cs


testGo.transform.position = Vector3.zero;
testObjects.Add(testGo);
var boxOverlap = TestBoxOverlapChecker.CreateChecker(agentGameObject: testGo, centerObject: testGo, initialColliderBufferSize: 2, maxColliderBufferSize: 5);
boxOverlap.Update();
boxOverlap.Perceive();
Assert.AreEqual(2, boxOverlap.ColliderBuffer.Length);
for (var i = 0; i < 3; i++)

boxGo.AddComponent<BoxCollider>();
testObjects.Add(boxGo);
}
boxOverlap.Update();
boxOverlap.Perceive();
Assert.AreEqual(4, boxOverlap.ColliderBuffer.Length);
for (var i = 0; i < 2; i++)

boxGo.AddComponent<BoxCollider>();
testObjects.Add(boxGo);
}
boxOverlap.Update();
boxOverlap.Perceive();
Assert.AreEqual(5, boxOverlap.ColliderBuffer.Length);
Object.DestroyImmediate(testGo);

testObjects.Add(boxGo);
}
boxOverlap.Update();
boxOverlap.Perceive();
helper.Verify(1, new List<GameObject> { testObjects[0] });
Object.DestroyImmediate(testGo);

testObjects.Add(boxGo);
}
boxOverlap.Update();
boxOverlap.Perceive();
helper.Verify(3, testObjects);
Object.DestroyImmediate(testGo);

foreach (var sensor in sensors)
{
var gridsensor = (GridSensorBase)sensor;
if (gridsensor.m_BoxOverlapChecker != null)
if (gridsensor.m_GridPerception != null)
{
numChecker += 1;
}

6
com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs


string[] tags = { k_Tag1, k_Tag2 };
gridSensorComponent.SetComponentParameters(tags, useOneHotTag: true, useGridSensorBase: true, useTestingGridSensor: true);
var gridSensors = gridSensorComponent.CreateSensors();
Assert.IsNotNull(((GridSensorBase)gridSensors[0]).m_BoxOverlapChecker);
Assert.IsNull(((GridSensorBase)gridSensors[1]).m_BoxOverlapChecker);
Assert.IsNull(((GridSensorBase)gridSensors[2]).m_BoxOverlapChecker);
Assert.IsNotNull(((GridSensorBase)gridSensors[0]).m_GridPerception);
Assert.IsNull(((GridSensorBase)gridSensors[1]).m_GridPerception);
Assert.IsNull(((GridSensorBase)gridSensors[2]).m_GridPerception);
}
[Test]

62
com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs


using System;
using UnityEngine;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// An interface for GridSensor perception that defines the grid cells and collider detecting strategies.
/// </summary>
internal interface IGridPerception
{
bool RotateWithAgent
{
get;
set;
}
LayerMask ColliderMask
{
get;
set;
}
/// <summary>Converts the index of the cell to the 3D point (y is zero) relative to grid center</summary>
/// <returns>Vector3 of the position of the center of the cell relative to grid center</returns>
/// <param name="cellIndex">The index of the cell</param>
Vector3 GetCellLocalPosition(int cellIndex);
/// <summary>
/// Converts the index of the cell to the 3D point (y is zero) in world space
/// based on the result from GetCellLocalPosition()
/// </summary>
/// <returns>Vector3 of the position of the center of the cell in world space</returns>
/// <param name="cellIndex">The index of the cell</param>
Vector3 GetCellGlobalPosition(int cellIndex);
Quaternion GetGridRotation();
/// <summary>
/// Perceive the latest grid status. Detect colliders for each cell, parse the collider arrays,
/// then trigger registered sensors to encode and update with the new grid status.
/// </summary>
void Perceive();
/// <summary>
/// Same as Perceive(), but only load data for debug gizmo.
/// </summary>
void UpdateGizmo();
/// <summary>
/// Register a sensor to this GridPerception to receive the grid perception results.
/// When the GridPerception perceive a new observation, registered sensors will be triggered
/// to encode the new observation and update its data.
/// </summary>
void RegisterSensor(GridSensorBase sensor);
/// <summary>
/// Register an internal debug sensor.
/// Debug sensors will only be triggered when drawing debug gizmos.
/// </summary>
void RegisterDebugSensor(GridSensorBase debugSensor);
}
}

11
com.unity.ml-agents/Runtime/Sensors/IGridPerception.cs.meta


fileFormatVersion: 2
guid: 87820d9eb927c4fa483dff9289d983f1
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存