比较提交

...
此合并请求有变更与目标分支冲突。
/com.unity.ml-agents.extensions/Editor/GridSensorComponentEditor.cs
/com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs
/com.unity.ml-agents.extensions/Tests/Utils/GridObsTestComponents/SimpleTestGridSensor.cs
/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
/com.unity.ml-agents.extensions/Runtime/Sensors/GridSensorComponent.cs

4 次代码提交

作者 SHA1 备注 提交日期
Ruo-Ping Dong e763ce68 update sensor fields at runtime 4 年前
Ruo-Ping Dong 07b5c6dd more rename 4 年前
Ruo-Ping Dong 6ef766aa fix bug calculating grid local position 4 年前
Ruo-Ping Dong e81ac4b1 rename variables 4 年前
共有 5 个文件被更改,包括 177 次插入266 次删除
  1. 6
      com.unity.ml-agents.extensions/Editor/GridSensorComponentEditor.cs
  2. 8
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs
  3. 6
      com.unity.ml-agents.extensions/Tests/Utils/GridObsTestComponents/SimpleTestGridSensor.cs
  4. 389
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  5. 34
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensorComponent.cs

6
com.unity.ml-agents.extensions/Editor/GridSensorComponentEditor.cs


EditorGUILayout.LabelField("Grid Settings", EditorStyles.boldLabel);
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_CellScale)), true);
// We only supports 2D GridSensor now so display gridNumSide as Vector2
var gridNumSide = so.FindProperty(nameof(GridSensorComponent.m_GridNumSide));
var gridNumSide = so.FindProperty(nameof(GridSensorComponent.m_GridNum));
var gridNumSide2d = new Vector2Int(gridNumSide.vector3IntValue.x, gridNumSide.vector3IntValue.z);
var newGridNumSide = EditorGUILayout.Vector2IntField("Grid Num Side", gridNumSide2d);
gridNumSide.vector3IntValue = new Vector3Int(newGridNumSide.x, 1, newGridNumSide.y);

EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_DepthType)), true);
// channel depth
var channelDepth = so.FindProperty(nameof(GridSensorComponent.m_ChannelDepth));
var channelDepth = so.FindProperty(nameof(GridSensorComponent.m_ChannelDepths));
var newDepth = EditorGUILayout.IntField("Channel Depth", channelDepth.arraySize);
if (newDepth != channelDepth.arraySize)
{

EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_InitialColliderBufferSize)), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ObserveMask)), true);
EditorGUILayout.PropertyField(so.FindProperty(nameof(GridSensorComponent.m_ColliderMask)), true);
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
EditorGUILayout.LabelField("Sensor Settings", EditorStyles.boldLabel);

8
com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs


}
public static void SetComponentParameters(GridSensorComponent gridComponent, string[] detectableObjects, int[] channelDepth, GridDepthType gridDepthType,
float cellScaleX, float cellScaleZ, int gridWidth, int gridHeight, int observeMaskInt, bool rotateWithAgent, Color[] debugColors)
float cellScaleX, float cellScaleZ, int gridWidth, int gridHeight, int colliderMaskInt, bool rotateWithAgent, Color[] debugColors)
gridComponent.ChannelDepth = channelDepth;
gridComponent.ChannelDepths = channelDepth;
gridComponent.GridNumSide = new Vector3Int(gridWidth, 1, gridHeight);
gridComponent.ObserveMask = observeMaskInt;
gridComponent.GridNum = new Vector3Int(gridWidth, 1, gridHeight);
gridComponent.ColliderMask = colliderMaskInt;
gridComponent.RotateWithAgent = rotateWithAgent;
gridComponent.DebugColors = debugColors;
}

6
com.unity.ml-agents.extensions/Tests/Utils/GridObsTestComponents/SimpleTestGridSensor.cs


m_Sensor = new SimpleTestGridSensor(
SensorName,
CellScale,
GridNumSide,
GridNum,
ChannelDepth,
ChannelDepths,
ObserveMask,
ColliderMask,
DepthType,
RootReference,
CompressionType,

389
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


/// </summary>
public class GridSensor : ISensor, IBuiltInSensor
{
/// <summary>
/// Name of this grid sensor.
/// </summary>
string Name;
//
// Main Parameters
//
/// <summary>
/// The scale of each grid cell.
/// </summary>
Vector3 CellScale;
/// <summary>
/// The number of grid on each side.
/// </summary>
Vector3Int GridNumSide;
/// <summary>
/// Rotate the grid based on the direction the agent is facing.
/// </summary>
bool RotateWithAgent;
/// <summary>
/// Array holding the depth of each channel.
/// </summary>
int[] ChannelDepth;
/// <summary>
/// List of tags that are detected.
/// </summary>
string[] DetectableObjects;
/// <summary>
/// The layer mask.
/// </summary>
LayerMask ObserveMask;
/// <summary>
/// The data layout that the grid should output.
/// </summary>
GridDepthType gridDepthType = GridDepthType.Channel;
/// <summary>
/// The reference of the root of the agent. This is used to disambiguate objects with the same tag as the agent. Defaults to current GameObject.
/// </summary>
GameObject rootReference;
int MaxColliderBufferSize;
int InitialColliderBufferSize;
Collider[] m_ColliderBuffer;
float[] m_ChannelBuffer;
string m_Name;
Vector3 m_CellScale;
Vector3Int m_GridNum;
bool m_RotateWithAgent;
GameObject m_RootReference;
int m_MaxColliderBufferSize;
int m_InitialColliderBufferSize;
LayerMask m_ColliderMask;
GridDepthType m_GridDepthType;
int[] m_ChannelDepths;
string[] m_DetectableObjects;
SensorCompressionType m_CompressionType;
ObservationSpec m_ObservationSpec;
//
// Hidden Parameters
//
/// <summary>
/// The total number of observations per cell of the grid. Its equivalent to the "channel" on the outgoing tensor.
/// </summary>
int ObservationPerCell;
/// <summary>
/// The offsets used to specify where within a cell's allotted data, certain observations will be inserted.
/// </summary>
int[] ChannelOffsets;
/// <summary>
/// The main storage of perceptual information.
/// </summary>
// Buffers
/// <summary>
/// The default value of the perceptionBuffer when using the ChannelHot DepthType. Used to reset the array/
/// </summary>
/// <summary>
/// Array of Colors needed in order to load the values of the perception buffer to a texture.
/// </summary>
Texture2D m_PerceptionTexture;
Collider[] m_ColliderBuffer;
float[] m_CellDataBuffer;
int[] m_ChannelOffsets;
Vector3[] m_CellLocalPositions;
int[] m_GizmoColorIndexes;
Vector3[] m_CellGlobalPosition;
/// <summary>
/// Texture where the colors are written to so that they can be compressed in PNG format.
/// </summary>
Texture2D m_perceptionTexture2D;
//
//
int m_NumCells;
int m_CellObservationSize;
float m_InverseSphereRadius;
/// <summary>
/// Radius of grid, used for normalizing the distance.
/// </summary>
float InverseSphereRadius;
/// <summary>
/// Total Number of cells (width*height)
/// </summary>
int NumCells;
/// <summary>
/// Offset used for calculating CellToPoint
/// </summary>
float OffsetGridNumSide = 7.5f; // (gridNumSideZ - 1) / 2;
/// <summary>
/// Cached ObservationSpec
/// </summary>
ObservationSpec m_ObservationSpec;
//
// Debug Parameters
//
SensorCompressionType m_CompressionType = SensorCompressionType.PNG;
/// <summary>
/// Array of colors displaying the DebugColors for each cell in OnDrawGizmos. Only updated if ShowGizmos.
/// </summary>
int[] m_CellActivity;
/// <summary>
/// Array of global positions where each position is the center of a cell.
/// </summary>
Vector3[] m_GizmoCellPosition;
/// <summary>
/// Array of local positions where each position is the center of a cell.
/// </summary>
Vector3[] CellPoints;
Vector3Int gridNumSide,
Vector3Int gridNum,
int[] channelDepth,
int[] channelDepths,
LayerMask observeMask,
LayerMask colliderMask,
GameObject root,
GameObject rootReference,
Name = name;
CellScale = cellScale;
GridNumSide = gridNumSide;
if (GridNumSide.y != 1)
m_Name = name;
m_CellScale = cellScale;
m_GridNum = gridNum;
m_RotateWithAgent = rotateWithAgent;
m_RootReference = rootReference;
m_MaxColliderBufferSize = maxColliderBufferSize;
m_InitialColliderBufferSize = initialColliderBufferSize;
m_ColliderMask = colliderMask;
m_GridDepthType = depthType;
m_ChannelDepths = channelDepths;
m_DetectableObjects = detectableObjects;
m_CompressionType = compression;
if (m_GridNum.y != 1)
RotateWithAgent = rotateWithAgent;
ChannelDepth = channelDepth;
DetectableObjects = detectableObjects;
ObserveMask = observeMask;
gridDepthType = depthType;
rootReference = root;
CompressionType = compression;
MaxColliderBufferSize = maxColliderBufferSize;
InitialColliderBufferSize = initialColliderBufferSize;
if (gridDepthType == GridDepthType.Counting && DetectableObjects.Length != ChannelDepth.Length)
if (m_GridDepthType == GridDepthType.Counting && m_DetectableObjects.Length != m_ChannelDepths.Length)
{
throw new UnityAgentsException("The channels of a CountingGridSensor is equal to the number of detectableObjects");
}

InitCellPoints();
ResetPerceptionBuffer();
m_ObservationSpec = ObservationSpec.Visual(GridNumSide.x, GridNumSide.z, ObservationPerCell);
m_perceptionTexture2D = new Texture2D(GridNumSide.x, GridNumSide.z, TextureFormat.RGB24, false);
m_ColliderBuffer = new Collider[Math.Min(MaxColliderBufferSize, InitialColliderBufferSize)];
m_ObservationSpec = ObservationSpec.Visual(m_GridNum.x, m_GridNum.z, m_CellObservationSize);
m_PerceptionTexture = new Texture2D(m_GridNum.x, m_GridNum.z, TextureFormat.RGB24, false);
m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_InitialColliderBufferSize)];
}
public SensorCompressionType CompressionType

}
public int[] CellActivity
public bool RotateWithAgent
{
get { return m_RotateWithAgent; }
set { m_RotateWithAgent = value; }
}
public LayerMask ColliderMask
{
get { return m_ColliderMask; }
set { m_ColliderMask = value; }
}
public int[] GizmoColorIndexes
get { return m_CellActivity; }
get { return m_GizmoColorIndexes; }
}
/// <summary>

{
NumCells = GridNumSide.x * GridNumSide.z;
float sphereRadiusX = (CellScale.x * GridNumSide.x) / Mathf.Sqrt(2);
float sphereRadiusZ = (CellScale.z * GridNumSide.z) / Mathf.Sqrt(2);
InverseSphereRadius = 1.0f / Mathf.Max(sphereRadiusX, sphereRadiusZ);
OffsetGridNumSide = (GridNumSide.z - 1f) / 2f;
m_NumCells = m_GridNum.x * m_GridNum.z;
float sphereRadiusX = (m_CellScale.x * m_GridNum.x) / Mathf.Sqrt(2);
float sphereRadiusZ = (m_CellScale.z * m_GridNum.z) / Mathf.Sqrt(2);
m_InverseSphereRadius = 1.0f / Mathf.Max(sphereRadiusX, sphereRadiusZ);
}
/// <summary>

void InitDepthType()
{
if (gridDepthType == GridDepthType.ChannelHot)
if (m_GridDepthType == GridDepthType.ChannelHot)
ObservationPerCell = ChannelDepth.Sum();
m_CellObservationSize = m_ChannelDepths.Sum();
ChannelOffsets = new int[ChannelDepth.Length];
for (int i = 1; i < ChannelDepth.Length; i++)
m_ChannelOffsets = new int[m_ChannelDepths.Length];
for (int i = 1; i < m_ChannelDepths.Length; i++)
ChannelOffsets[i] = ChannelOffsets[i - 1] + ChannelDepth[i - 1];
m_ChannelOffsets[i] = m_ChannelOffsets[i - 1] + m_ChannelDepths[i - 1];
m_ChannelHotDefaultPerceptionBuffer = new float[ObservationPerCell];
for (int i = 0; i < ChannelDepth.Length; i++)
m_ChannelHotDefaultPerceptionBuffer = new float[m_CellObservationSize];
for (int i = 0; i < m_ChannelDepths.Length; i++)
if (ChannelDepth[i] > 1)
if (m_ChannelDepths[i] > 1)
m_ChannelHotDefaultPerceptionBuffer[ChannelOffsets[i]] = 1;
m_ChannelHotDefaultPerceptionBuffer[m_ChannelOffsets[i]] = 1;
ObservationPerCell = ChannelDepth.Length;
m_CellObservationSize = m_ChannelDepths.Length;
Assert.IsTrue(ObservationPerCell < (255 * 3), "The maximum number of channels per cell must be less than 255 * 3");
Assert.IsTrue(m_CellObservationSize < (255 * 3), "The maximum number of channels per cell must be less than 255 * 3");
}
/// <summary>

{
CellPoints = new Vector3[NumCells];
m_CellLocalPositions = new Vector3[m_NumCells];
for (int i = 0; i < NumCells; i++)
for (int i = 0; i < m_NumCells; i++)
CellPoints[i] = CellToLocalPosition(i);
m_CellLocalPositions[i] = CellToLocalPoint(i);
}
}

{
if (m_PerceptionBuffer != null)
{
if (gridDepthType == GridDepthType.ChannelHot)
if (m_GridDepthType == GridDepthType.ChannelHot)
for (int i = 0; i < NumCells; i++)
for (int i = 0; i < m_NumCells; i++)
Array.Copy(m_ChannelHotDefaultPerceptionBuffer, 0, m_PerceptionBuffer, i * ObservationPerCell, ObservationPerCell);
Array.Copy(m_ChannelHotDefaultPerceptionBuffer, 0, m_PerceptionBuffer, i * m_CellObservationSize, m_CellObservationSize);
}
}
else

}
else
{
m_PerceptionBuffer = new float[ObservationPerCell * NumCells];
m_ColliderBuffer = new Collider[Math.Min(MaxColliderBufferSize, InitialColliderBufferSize)];
m_ChannelBuffer = new float[ChannelDepth.Length];
m_PerceptionColors = new Color[NumCells];
m_GizmoCellPosition = new Vector3[NumCells];
m_PerceptionBuffer = new float[m_CellObservationSize * m_NumCells];
m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_InitialColliderBufferSize)];
m_CellDataBuffer = new float[m_ChannelDepths.Length];
m_PerceptionColors = new Color[m_NumCells];
m_CellGlobalPosition = new Vector3[m_NumCells];
}
}

if (m_CellActivity == null)
m_CellActivity = new int[NumCells];
if (m_GizmoColorIndexes == null)
m_GizmoColorIndexes = new int[m_NumCells];
for (int i = 0; i < NumCells; i++)
for (int i = 0; i < m_NumCells; i++)
m_CellActivity[i] = -1;
m_GizmoColorIndexes[i] = -1;
}
}

{
return Name;
return m_Name;
}
/// <inheritdoc/>

using (TimerStack.Instance.Scoped("GridSensor.GetCompressedObservation"))
{
var allBytes = new List<byte>();
var numImages = (ObservationPerCell + 2) / 3;
var numImages = (m_CellObservationSize + 2) / 3;
ChannelsToTexture(channelIndex, Math.Min(3, ObservationPerCell - channelIndex));
allBytes.AddRange(m_perceptionTexture2D.EncodeToPNG());
ChannelsToTexture(channelIndex, Math.Min(3, m_CellObservationSize - channelIndex));
allBytes.AddRange(m_PerceptionTexture.EncodeToPNG());
}
return allBytes.ToArray();

/// <param name="numChannelsToAdd"></param>
void ChannelsToTexture(int channelIndex, int numChannelsToAdd)
{
for (int i = 0; i < NumCells; i++)
for (int i = 0; i < m_NumCells; i++)
m_PerceptionColors[i][j] = m_PerceptionBuffer[i * ObservationPerCell + channelIndex + j];
m_PerceptionColors[i][j] = m_PerceptionBuffer[i * m_CellObservationSize + channelIndex + j];
m_perceptionTexture2D.SetPixels(m_PerceptionColors);
m_PerceptionTexture.SetPixels(m_PerceptionColors);
}
/// <summary>

ResetPerceptionBuffer();
using (TimerStack.Instance.Scoped("GridSensor.Perceive"))
{
var halfCellScale = new Vector3(CellScale.x / 2f, CellScale.y, CellScale.z / 2f);
var halfCellScale = new Vector3(m_CellScale.x / 2f, m_CellScale.y, m_CellScale.z / 2f);
for (var cellIndex = 0; cellIndex < NumCells; cellIndex++)
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)
{
var cellCenter = GetCellGlobalPosition(cellIndex);
var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, halfCellScale, GetGridRotation());

if (gridDepthType == GridDepthType.Counting)
if (m_GridDepthType == GridDepthType.Counting)
{
ParseCollidersAll(m_ColliderBuffer, numFound, cellIndex, cellCenter);
}

// until we're sure we can hold them all (or until we hit the max size).
while (true)
{
numFound = Physics.OverlapBoxNonAlloc(cellCenter, halfCellScale, m_ColliderBuffer, rotation, ObserveMask);
if (numFound == m_ColliderBuffer.Length && m_ColliderBuffer.Length < MaxColliderBufferSize)
numFound = Physics.OverlapBoxNonAlloc(cellCenter, halfCellScale, m_ColliderBuffer, rotation, m_ColliderMask);
if (numFound == m_ColliderBuffer.Length && m_ColliderBuffer.Length < m_MaxColliderBufferSize)
m_ColliderBuffer = new Collider[Math.Min(MaxColliderBufferSize, m_ColliderBuffer.Length * 2)];
InitialColliderBufferSize = m_ColliderBuffer.Length;
m_ColliderBuffer = new Collider[Math.Min(m_MaxColliderBufferSize, m_ColliderBuffer.Length * 2)];
m_InitialColliderBufferSize = m_ColliderBuffer.Length;
}
else
{

var currentColliderGo = foundColliders[i].gameObject;
// Continue if the current collider go is the root reference
if (ReferenceEquals(currentColliderGo, rootReference))
if (ReferenceEquals(currentColliderGo, m_RootReference))
var currentDistanceSquared = (closestColliderPoint - rootReference.transform.position).sqrMagnitude;
var currentDistanceSquared = (closestColliderPoint - m_RootReference.transform.position).sqrMagnitude;
for (var ii = 0; ii < DetectableObjects.Length; ii++)
for (var ii = 0; ii < m_DetectableObjects.Length; ii++)
if (currentColliderGo.CompareTag(DetectableObjects[ii]))
if (currentColliderGo.CompareTag(m_DetectableObjects[ii]))
{
index = ii;
break;

if (!ReferenceEquals(closestColliderGo, null))
{
LoadObjectData(closestColliderGo, cellIndex, (float)Math.Sqrt(minDistanceSquared) * InverseSphereRadius);
LoadObjectData(closestColliderGo, cellIndex, (float)Math.Sqrt(minDistanceSquared) * m_InverseSphereRadius);
}
Profiler.EndSample();
}

currentColliderGo = foundColliders[i].gameObject;
// Continue if the current collider go is the root reference
if (currentColliderGo == rootReference)
if (currentColliderGo == m_RootReference)
Vector3.Distance(closestColliderPoint, rootReference.transform.position) * InverseSphereRadius);
Vector3.Distance(closestColliderPoint, m_RootReference.transform.position) * m_InverseSphereRadius);
}
Profiler.EndSample();
}

/// </example>
protected virtual float[] GetObjectData(GameObject currentColliderGo, float typeIndex, float normalizedDistance)
{
Array.Clear(m_ChannelBuffer, 0, m_ChannelBuffer.Length);
m_ChannelBuffer[0] = typeIndex;
return m_ChannelBuffer;
Array.Clear(m_CellDataBuffer, 0, m_CellDataBuffer.Length);
m_CellDataBuffer[0] = typeIndex;
return m_CellDataBuffer;
}
/// <summary>

if (channelValues[j] < 0)
throw new UnityAgentsException("Expected ChannelValue[" + j + "] for " + currentColliderGo.name + " to be non-negative, was " + channelValues[j]);
if (channelValues[j] > ChannelDepth[j])
throw new UnityAgentsException("Expected ChannelValue[" + j + "] for " + currentColliderGo.name + " to be less than ChannelDepth[" + j + "] (" + ChannelDepth[j] + "), was " + channelValues[j]);
if (channelValues[j] > m_ChannelDepths[j])
throw new UnityAgentsException("Expected ChannelValue[" + j + "] for " + currentColliderGo.name + " to be less than ChannelDepth[" + j + "] (" + m_ChannelDepths[j] + "), was " + channelValues[j]);
}
}

protected virtual void LoadObjectData(GameObject currentColliderGo, int cellIndex, float normalizedDistance)
{
Profiler.BeginSample("GridSensor.LoadObjectData");
var channelHotVals = new ArraySegment<float>(m_PerceptionBuffer, cellIndex * ObservationPerCell, ObservationPerCell);
for (var i = 0; i < DetectableObjects.Length; i++)
var channelHotVals = new ArraySegment<float>(m_PerceptionBuffer, cellIndex * m_CellObservationSize, m_CellObservationSize);
for (var i = 0; i < m_DetectableObjects.Length; i++)
if (gridDepthType != GridDepthType.Counting)
if (m_GridDepthType != GridDepthType.Counting)
{
for (var ii = 0; ii < channelHotVals.Count; ii++)
{

if (!ReferenceEquals(currentColliderGo, null) && currentColliderGo.CompareTag(DetectableObjects[i]))
if (!ReferenceEquals(currentColliderGo, null) && currentColliderGo.CompareTag(m_DetectableObjects[i]))
{
// TODO: Create the array already then set the values using "out" in GetObjectData
// Using i+1 as the type index as "0" represents "empty"

switch (gridDepthType)
switch (m_GridDepthType)
{
case GridDepthType.Channel:
{

// Array.Copy(channelValues, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
for (int j = 0; j < channelValues.Length; j++)
{
channelValues[j] /= ChannelDepth[j];
channelValues[j] /= m_ChannelDepths[j];
Array.Copy(channelValues, 0, m_PerceptionBuffer, cellIndex * ObservationPerCell, ObservationPerCell);
Array.Copy(channelValues, 0, m_PerceptionBuffer, cellIndex * m_CellObservationSize, m_CellObservationSize);
break;
}

// Array.Copy(channelHotVals, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
for (int j = 0; j < channelValues.Length; j++)
{
if (ChannelDepth[j] > 1)
if (m_ChannelDepths[j] > 1)
m_PerceptionBuffer[channelHotVals.Offset + (int)channelValues[j] + ChannelOffsets[j]] = 1f;
m_PerceptionBuffer[channelHotVals.Offset + (int)channelValues[j] + m_ChannelOffsets[j]] = 1f;
m_PerceptionBuffer[channelHotVals.Offset + ChannelOffsets[j]] = channelValues[j];
m_PerceptionBuffer[channelHotVals.Offset + m_ChannelOffsets[j]] = channelValues[j];
}
}
break;

// The observations are "channel count" so each grid is WxHxC where C is the number of tags
// This means that each value channelValues[i] is a counter of gameobject included into grid cells
// where i is the index of the tag in DetectableObjects
int countIndex = cellIndex * ObservationPerCell + i;
m_PerceptionBuffer[countIndex] = Mathf.Min(1f, m_PerceptionBuffer[countIndex] + 1f / ChannelDepth[i]);
int countIndex = cellIndex * m_CellObservationSize + i;
m_PerceptionBuffer[countIndex] = Mathf.Min(1f, m_PerceptionBuffer[countIndex] + 1f / m_ChannelDepths[i]);
break;
}
}

/// <summary>Converts the index of the cell to the 3D point (y is zero) relative to grid center</summary>
/// <returns>Vector3 of the position of the center of the cell relative to grid center</returns>
/// <param name="cell">The index of the cell</param>
Vector3 CellToLocalPosition(int cellIndex)
Vector3 CellToLocalPoint(int cellIndex)
float x = (cellIndex % GridNumSide.z - OffsetGridNumSide) * CellScale.x;
float z = (cellIndex / GridNumSide.z - OffsetGridNumSide) * CellScale.z - (GridNumSide.z - GridNumSide.x);
float x = (cellIndex / m_GridNum.z - m_GridNum.x / 2) * m_CellScale.x;
float z = (cellIndex % m_GridNum.z - m_GridNum.z / 2) * m_CellScale.z;
if (RotateWithAgent)
if (m_RotateWithAgent)
return rootReference.transform.TransformPoint(CellPoints[cellIndex]);
return m_RootReference.transform.TransformPoint(m_CellLocalPositions[cellIndex]);
return CellPoints[cellIndex] + rootReference.transform.position;
return m_CellLocalPositions[cellIndex] + m_RootReference.transform.position;
return RotateWithAgent ? rootReference.transform.rotation : Quaternion.identity;
return m_RotateWithAgent ? m_RootReference.transform.rotation : Quaternion.identity;
}
/// <inheritdoc/>

using (TimerStack.Instance.Scoped("GridSensor.Write"))
{
int index = 0;
for (var h = GridNumSide.z - 1; h >= 0; h--)
for (var h = m_GridNum.z - 1; h >= 0; h--)
for (var w = 0; w < GridNumSide.x; w++)
for (var w = 0; w < m_GridNum.x; w++)
for (var d = 0; d < ObservationPerCell; d++)
for (var d = 0; d < m_CellObservationSize; d++)
{
writer[h, w, d] = m_PerceptionBuffer[index];
index++;

internal int[] PerceiveGizmoColor()
{
ResetGizmoBuffer();
var halfCellScale = new Vector3(CellScale.x / 2f, CellScale.y, CellScale.z / 2f);
var halfCellScale = new Vector3(m_CellScale.x / 2f, m_CellScale.y, m_CellScale.z / 2f);
for (var cellIndex = 0; cellIndex < NumCells; cellIndex++)
for (var cellIndex = 0; cellIndex < m_NumCells; cellIndex++)
{
var cellCenter = GetCellGlobalPosition(cellIndex);
var numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, halfCellScale, GetGridRotation());

for (var i = 0; i < numFound; i++)
{
var currentColliderGo = m_ColliderBuffer[i].gameObject;
if (ReferenceEquals(currentColliderGo, rootReference))
if (ReferenceEquals(currentColliderGo, m_RootReference))
var currentDistanceSquared = (closestColliderPoint - rootReference.transform.position).sqrMagnitude;
var currentDistanceSquared = (closestColliderPoint - m_RootReference.transform.position).sqrMagnitude;
for (var ii = 0; ii < DetectableObjects.Length; ii++)
for (var ii = 0; ii < m_DetectableObjects.Length; ii++)
if (currentColliderGo.CompareTag(DetectableObjects[ii]))
if (currentColliderGo.CompareTag(m_DetectableObjects[ii]))
{
index = ii;
break;

tagIndex = index;
}
}
CellActivity[cellIndex] = tagIndex;
m_GizmoColorIndexes[cellIndex] = tagIndex;
return CellActivity;
return m_GizmoColorIndexes;
for (var i = 0; i < NumCells; i++)
for (var i = 0; i < m_NumCells; i++)
m_GizmoCellPosition[i] = GetCellGlobalPosition(i);
m_CellGlobalPosition[i] = GetCellGlobalPosition(i);
return m_GizmoCellPosition;
return m_CellGlobalPosition;
}
}
}

34
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensorComponent.cs


}
[HideInInspector, SerializeField]
internal Vector3Int m_GridNumSide = new Vector3Int(16, 1, 16);
internal Vector3Int m_GridNum = new Vector3Int(16, 1, 16);
public Vector3Int GridNumSide
public Vector3Int GridNum
get { return m_GridNumSide; }
get { return m_GridNum; }
m_GridNumSide = new Vector3Int(value.x, 1, value.z);
m_GridNum = new Vector3Int(value.x, 1, value.z);
m_GridNumSide = value;
m_GridNum = value;
}
}
}

}
[HideInInspector, SerializeField]
internal int[] m_ChannelDepth = new int[] { 1 };
internal int[] m_ChannelDepths = new int[] { 1 };
public int[] ChannelDepth
public int[] ChannelDepths
get { return m_ChannelDepth; }
set { m_ChannelDepth = value; }
get { return m_ChannelDepths; }
set { m_ChannelDepths = value; }
}
[HideInInspector, SerializeField]

}
[HideInInspector, SerializeField]
internal LayerMask m_ObserveMask;
internal LayerMask m_ColliderMask;
public LayerMask ObserveMask
public LayerMask ColliderMask
get { return m_ObserveMask; }
set { m_ObserveMask = value; }
get { return m_ColliderMask; }
set { m_ColliderMask = value; }
}
[HideInInspector, SerializeField]

m_Sensor = new GridSensor(
m_SensorName,
m_CellScale,
m_GridNumSide,
m_GridNum,
m_ChannelDepth,
m_ChannelDepths,
m_ObserveMask,
m_ColliderMask,
m_DepthType,
RootReference,
m_CompressionType,

if (m_Sensor != null)
{
m_Sensor.CompressionType = m_CompressionType;
m_Sensor.RotateWithAgent = m_RotateWithAgent;
m_Sensor.ColliderMask = m_ColliderMask;
}
}

正在加载...
取消
保存