浏览代码

Jobify grid sensor

/develop/input-actuator-tanks
Christopher Goy 4 年前
当前提交
dd6caee9
共有 7 个文件被更改,包括 462 次插入270 次删除
  1. 45
      com.unity.ml-agents.extensions/Runtime/Sensors/CountingGridSensor.cs
  2. 616
      com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
  3. 3
      com.unity.ml-agents.extensions/Runtime/Unity.ML-Agents.Extensions.asmdef
  4. 37
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/CountingGridSensorPerceiveTests.cs
  5. 25
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridObservationPerceiveTests.cs
  6. 3
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs
  7. 3
      com.unity.ml-agents.extensions/package.json

45
com.unity.ml-agents.extensions/Runtime/Sensors/CountingGridSensor.cs


using System;
using Unity.Collections;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Sensors

/// </summary>
public override void InitChannelHotDefaultPerceptionBuffer()
{
m_ChannelHotDefaultPerceptionBuffer = new float[ObservationPerCell];
m_ChannelHotDefaultPerceptionBuffer = new NativeArray<float>(ObservationPerCell, Allocator.Persistent);
}
/// <inheritdoc/>

/// For each collider, calls LoadObjectData on the gameobejct
/// </summary>
/// <param name="foundColliders">The array of colliders</param>
/// <param name="numFound"></param>
/// <param name="cellIndex">The cell index the collider is in</param>
/// <param name="cellCenter">the center of the cell the collider is in</param>
protected override void ParseColliders(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter)

closestColliderPoint = foundColliders[i].ClosestPointOnBounds(cellCenter);
LoadObjectData(currentColliderGo, cellIndex,
Vector3.Distance(closestColliderPoint, transform.position) * InverseSphereRadius);
if (m_DetectableObjectToIndex.TryGetValue(currentColliderGo.tag, out var detectableIndex))
{
LoadObjectData(currentColliderGo, cellIndex,
detectableIndex, Vector3.Distance(closestColliderPoint, transform.position) * InverseSphereRadius);
}
}
}

/// </summary>
/// <param name="currentColliderGo">the current game object</param>
/// <param name="cellIndex">the index of the cell</param>
/// <param name="detectableIndex"></param>
protected override void LoadObjectData(GameObject currentColliderGo, int cellIndex, float normalizedDistance)
protected override void LoadObjectData(GameObject currentColliderGo, int cellIndex, int detectableIndex, float normalizedDistance)
for (int i = 0; i < DetectableObjects.Length; i++)
if (ShowGizmos)
if (currentColliderGo != null && currentColliderGo.CompareTag(DetectableObjects[i]))
Color debugRayColor = Color.white;
if (DebugColors.Length > 0)
if (ShowGizmos)
{
Color debugRayColor = Color.white;
if (DebugColors.Length > 0)
{
debugRayColor = DebugColors[i];
}
CellActivity[cellIndex] = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
}
/// <remarks>
/// The observations are "channel count" so each grid is WxHxC where C is the number of tags
/// This means that each value channelValues[i] is a counter of gameobject included into grid cells where i is the index of the tag in DetectableObjects
/// </remarks>
int countIndex = cellIndex * ObservationPerCell + i;
m_PerceptionBuffer[countIndex] = Mathf.Min(1f, m_PerceptionBuffer[countIndex] + 1f / ChannelDepth[i]);
break;
debugRayColor = DebugColors[detectableIndex];
CellActivity[cellIndex] = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
// <remarks>
// The observations are "channel count" so each grid is WxHxC where C is the number of tags
// This means that each value channelValues[i] is a counter of gameobject included into grid cells where i is the index of the tag in DetectableObjects
// </remarks>
int countIndex = cellIndex * ObservationPerCell + detectableIndex;
m_PerceptionBuffer[countIndex] = Mathf.Min(1f, m_PerceptionBuffer[countIndex] + (1f / ChannelDepth[detectableIndex]));
}
}
}

616
com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs


using System;
using System.Collections.Generic;
using Unity.Collections;
using Unity.Jobs;
using UnityEngine.Jobs;
using Unity.Burst;
#if UNITY_EDITOR
using UnityEditor;
#endif
namespace Unity.MLAgents.Extensions.Sensors
{

"Physics API. If the number of colliders found is >= InitialColliderBufferSize the array " +
"will be resized to double its current size. The hard coded absolute size is 500.")]
public int InitialColliderBufferSize = 4;
Collider[] m_ColliderBuffer;
// Collider[] m_ColliderBuffer;
NativeArray<RaycastHit> m_Hits;
NativeArray<int> m_HitIndexes;
NativeArray<BoxcastCommand> m_BoxcastCommands;
protected Dictionary<string, int> m_DetectableObjectToIndex = new Dictionary<string, int>();
/// <summary>
/// The offsets used to specify where within a cell's allotted data, certain observations will be inserted.
/// </summary>
[HideInInspector]
public NativeArray<int> ChannelOffsets;
public NativeArray<float> PerceptionBuffer => m_PerceptionBuffer;
/// <summary>
/// The main storage of perceptual information.
/// </summary>
protected NativeArray<float> m_PerceptionBuffer;
/// <summary>
/// The default value of the perceptionBuffer when using the ChannelHot DepthType. Used to reset the array/
/// </summary>
protected NativeArray<float> m_ChannelHotDefaultPerceptionBuffer;
/// <summary>
/// Array of colors displaying the DebugColors for each cell in OnDrawGizmos. Only updated if ShowGizmos.
/// </summary>
protected NativeArray<Color> CellActivity;
/// <summary>
/// Array of positions where each position is the center of a cell.
/// </summary>
private NativeArray<Vector3> CellPoints;
float[] m_ChannelBuffer;

[HideInInspector]
public int NumberOfObservations;
/// <summary>
/// The offsets used to specify where within a cell's allotted data, certain observations will be inserted.
/// </summary>
[HideInInspector]
public int[] ChannelOffsets;
/// <summary>
/// The main storage of perceptual information.
/// </summary>
protected float[] m_PerceptionBuffer;
/// <summary>
/// The default value of the perceptionBuffer when using the ChannelHot DepthType. Used to reset the array/
/// </summary>
protected float[] m_ChannelHotDefaultPerceptionBuffer;
/// <summary>
/// Array of Colors needed in order to load the values of the perception buffer to a texture.

public SensorCompressionType CompressionType = SensorCompressionType.PNG;
/// <summary>
/// Array of colors displaying the DebugColors for each cell in OnDrawGizmos. Only updated if ShowGizmos.
/// </summary>
protected Color[] CellActivity;
/// <summary>
/// Array of positions where each position is the center of a cell.
/// </summary>
private Vector3[] CellPoints;
/// <summary>
/// List representing the multiple compressed images of all of the grids
/// </summary>
private List<byte[]> compressedImgs;

private List<byte[]> byteSizesBytesList;
private Color DebugDefaultColor = new Color(1f, 1f, 1f, 0.25f);
internal JobHandle m_boxcastJobHandle;
/// <inheritdoc/>
public override ISensor CreateSensor()

public void InitGridParameters()
{
NumCells = GridNumSideX * GridNumSideZ;
float sphereRadiusX = (CellScaleX * GridNumSideX) / Mathf.Sqrt(2);
float sphereRadiusZ = (CellScaleZ * GridNumSideZ) / Mathf.Sqrt(2);
var sphereRadiusX = (CellScaleX * GridNumSideX) / Mathf.Sqrt(2);
var sphereRadiusZ = (CellScaleZ * GridNumSideZ) / Mathf.Sqrt(2);
ChannelOffsets = new int[ChannelDepth.Length];
if (ChannelOffsets.IsCreated)
{
ChannelOffsets.Dispose();
}
ChannelOffsets = new NativeArray<int>(ChannelDepth.Length, Allocator.Persistent);
DiffNumSideZX = (GridNumSideZ - GridNumSideX);
OffsetGridNumSide = (GridNumSideZ - 1f) / 2f;
HalfOfGridX = CellScaleX * GridNumSideX / 2;

ObservationPerCell = 0;
ChannelOffsets[ChannelOffsets.Length - 1] = 0;
for (int i = 1; i < ChannelDepth.Length; i++)
for (var i = 1; i < ChannelDepth.Length; i++)
for (int i = 0; i < ChannelDepth.Length; i++)
for (var i = 0; i < ChannelDepth.Length; i++)
{
ObservationPerCell += ChannelDepth[i];
}

/// </summary>
private void InitCellPoints()
{
CellPoints = new Vector3[NumCells];
if (CellPoints.IsCreated)
{
CellPoints.Dispose();
}
CellPoints = new NativeArray<Vector3>(NumCells, Allocator.Persistent);
for (int i = 0; i < NumCells; i++)
for (var i = 0; i < NumCells; i++)
{
CellPoints[i] = CellToPoint(i, false);
}

/// </summary>
public virtual void InitChannelHotDefaultPerceptionBuffer()
{
m_ChannelHotDefaultPerceptionBuffer = new float[ObservationPerCell];
for (int i = 0; i < ChannelDepth.Length; i++)
if (m_ChannelHotDefaultPerceptionBuffer.IsCreated)
{
m_ChannelHotDefaultPerceptionBuffer.Dispose();
}
m_ChannelHotDefaultPerceptionBuffer = new NativeArray<float>(ObservationPerCell, Allocator.Persistent);
for (var i = 0; i < ChannelDepth.Length; i++)
{
if (ChannelDepth[i] > 1)
{

Initialized = true;
NumberOfObservations = ObservationPerCell * NumCells;
m_PerceptionBuffer = new float[NumberOfObservations];
if (m_PerceptionBuffer.IsCreated)
{
m_boxcastJobHandle.Complete();
m_PerceptionBuffer.Dispose();
}
m_PerceptionBuffer = new NativeArray<float>(NumberOfObservations, Allocator.Persistent);
if (gridDepthType == GridDepthType.ChannelHot)
{
InitChannelHotDefaultPerceptionBuffer();

else
NumImages += 1;
CellActivity = new Color[NumCells];
m_ChannelBuffer = new float[ChannelDepth.Length];
if (CellActivity.IsCreated)
{
CellActivity.Dispose();
}
CellActivity = new NativeArray<Color>(NumCells, Allocator.Persistent);
}
/// <summary>

InitDepthType();
InitCellPoints();
InitPerceptionBuffer();
m_ColliderBuffer = new Collider[Math.Min(MaxColliderBufferSize, InitialColliderBufferSize)];
if (m_Hits.IsCreated)
{
m_Hits.Dispose();
}
m_Hits = new NativeArray<RaycastHit>(NumCells, Allocator.Persistent);
if (m_HitIndexes.IsCreated)
{
m_HitIndexes.Dispose();
}
m_HitIndexes = new NativeArray<int>(NumCells, Allocator.Persistent);
if (m_BoxcastCommands.IsCreated)
{
m_BoxcastCommands.Dispose();
}
m_BoxcastCommands = new NativeArray<BoxcastCommand>(NumCells, Allocator.Persistent);
for (var i = 0; i < DetectableObjects.Length; i++)
{
m_DetectableObjectToIndex[DetectableObjects[i]] = i;
}
// Default root reference to current game object
if (rootReference == null)
rootReference = gameObject;

/// <inheritdoc cref="ISensor.Reset"/>
void ISensor.Reset() { }
[BurstCompile(CompileSynchronously = true)]
struct ClearBufferOneHotJob : IJobParallelFor
{
public NativeArray<float> PerceptionBuf;
[ReadOnly]
public NativeArray<float> ChannelHotBuf;
public int ObservationPerCell;
public void Execute(int index)
{
PerceptionBuf[index] = ChannelHotBuf[index % ObservationPerCell];
}
}
[BurstCompile(CompileSynchronously = true)]
struct ClearBufferChannelJob : IJobParallelFor
{
public NativeArray<float> PerceptionBuf;
public void Execute(int index)
{
PerceptionBuf[index] = 0;
}
}
[BurstCompile(CompileSynchronously = true)]
struct ClearCellActivityJob : IJobParallelFor
{
public NativeArray<Color> CellActivity;
public Color DebugDefault;
public void Execute(int index)
{
CellActivity[index] = DebugDefault;
}
}
/// <summary>
/// Clears the perception buffer before loading in new data. If the gridDepthType is ChannelHot, then it initializes the
/// Reset() also reinits the cell activity array (for debug)

if (m_PerceptionBuffer != null)
Profiler.BeginSample("ClearPerceptionBuffer");
for (int i = 0; i < NumCells; i++)
for (var i = 0; i < m_PerceptionBuffer.Length; i++)
Array.Copy(m_ChannelHotDefaultPerceptionBuffer, 0, m_PerceptionBuffer, i * ObservationPerCell, ObservationPerCell);
m_PerceptionBuffer[i] = m_ChannelHotDefaultPerceptionBuffer[i % ObservationPerCell];
Array.Clear(m_PerceptionBuffer, 0, m_PerceptionBuffer.Length);
for (var i = 0; i < m_PerceptionBuffer.Length; i++)
{
m_PerceptionBuffer[i] = 0;
}
}
else
{
m_PerceptionBuffer = new float[NumberOfObservations];
m_ColliderBuffer = new Collider[Math.Min(MaxColliderBufferSize, InitialColliderBufferSize)];
}
if (ShowGizmos)
{
// Ensure to init arrays if not yet assigned (for editor)
if (CellActivity == null)
CellActivity = new Color[NumCells];
// Assign the default color to the cell activities
for (int i = 0; i < NumCells; i++)
if (ShowGizmos)
CellActivity[i] = DebugDefaultColor;
// Assign the default color to the cell activities
for (var i = 0; i < NumCells; i++)
{
CellActivity[i] = DebugDefaultColor;
}
Profiler.EndSample();
}
void OnDisable()
{
m_Hits.Dispose();
m_HitIndexes.Dispose();
m_BoxcastCommands.Dispose();
m_PerceptionBuffer.Dispose();
CellActivity.Dispose();
CellPoints.Dispose();
m_ChannelHotDefaultPerceptionBuffer.Dispose();
ChannelOffsets.Dispose();
}
/// <summary>Gets the shape of the grid observation</summary>

/// <returns>byte[] containing the compressed observation of the grid observation</returns>
public byte[] GetCompressedObservation()
{
using (TimerStack.Instance.Scoped("GridSensor.GetCompressedObservation"))
UpdateBufferFromJob();
var allBytes = new List<byte>();
Profiler.BeginSample("GridSensor.GetCompressedObservation");
Perceive(); // Fill the perception buffer with observed data
var allBytes = new List<byte>();
for (int i = 0; i < NumImages - 1; i++)
for (var i = 0; i < NumImages - 1; i++)
{
ChannelsToTexture(3 * i, 3);
allBytes.AddRange(m_perceptionTexture2D.EncodeToPNG());

allBytes.AddRange(m_perceptionTexture2D.EncodeToPNG());
return allBytes.ToArray();
Profiler.EndSample();
return allBytes.ToArray();
}
/// <summary>

/// <param name="numChannelsToAdd"></param>
protected void ChannelsToTexture(int channelIndex, int numChannelsToAdd)
{
for (int i = 0; i < NumCells; i++)
for (var i = 0; i < NumCells; i++)
for (int j = 0; j < numChannelsToAdd; j++)
for (var j = 0; j < numChannelsToAdd; j++)
{
m_PerceptionColors[i][j] = m_PerceptionBuffer[i * ObservationPerCell + channelIndex + j];
}

[BurstCompile(CompileSynchronously = true)]
struct CreateBoxcastBatch : IJobParallelFor
{
public NativeArray<BoxcastCommand> Commands;
public Vector3 halfScale;
public int ObserveMask;
[ReadOnly]
public NativeArray<Vector3> CellPoints;
public Vector3 Pos;
public void Execute(int index)
{
var cellCenter = Pos + CellPoints[index];
var rotation = Quaternion.identity;
Commands[index] = new BoxcastCommand(new Vector3(cellCenter.x, cellCenter.y - 2.0f, cellCenter.z),
halfScale,
rotation,
Vector3.up,
2.0f,
ObserveMask);
}
}
[BurstCompile(CompileSynchronously = true)]
struct CreateBoxcastRotateBatch : IJobParallelFor
{
public NativeArray<BoxcastCommand> Commands;
public Vector3 halfScale;
public int ObserveMask;
[ReadOnly]
public NativeArray<Vector3> CellPoints;
[ReadOnly]
public Matrix4x4 Mat;
public Quaternion Rotation;
public void Execute(int index)
{
var cellCenter = Mat.MultiplyPoint(CellPoints[index]);
var rotation = Rotation;
Commands[index] = new BoxcastCommand(new Vector3(cellCenter.x, cellCenter.y - 2.0f, cellCenter.z),
halfScale,
rotation,
Vector3.up,
2.0f,
ObserveMask);
}
}
// ReSharper disable Unity.PerformanceAnalysis
/// <summary>
/// Perceive - Clears the buffers, calls overlap box on the actual cell (the actual perception part)
/// for all found colliders, LoadObjectData is called

public float[] Perceive()
public NativeArray<float> Perceive()
if (m_ColliderBuffer == null)
if (!m_PerceptionBuffer.IsCreated || !m_Hits.IsCreated || !m_BoxcastCommands.IsCreated)
return Array.Empty<float>();
return new NativeArray<float>();
ClearPerceptionBuffer();
using (TimerStack.Instance.Scoped("GridSensor.Perceive"))
Profiler.BeginSample("GridSensor.Perceive");
var halfCellScale = new Vector3(CellScaleX / 2f, CellScaleY, CellScaleZ / 2f);
for (var cellIndex = 0; cellIndex < NumCells; cellIndex++)
JobHandle clearHandle;
if (gridDepthType == GridDepthType.ChannelHot)
int numFound;
Vector3 cellCenter;
if (RotateToAgent)
var clearJob = new ClearBufferOneHotJob
Transform transform1;
cellCenter = (transform1 = transform).TransformPoint(CellPoints[cellIndex]);
numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, halfCellScale, transform1.rotation);
}
else
PerceptionBuf = m_PerceptionBuffer,
ChannelHotBuf = m_ChannelHotDefaultPerceptionBuffer,
ObservationPerCell = ObservationPerCell
};
clearHandle = clearJob.Schedule(NumberOfObservations, NumberOfObservations / 12);
}
else
{
var clearJob = new ClearBufferChannelJob
cellCenter = transform.position + CellPoints[cellIndex];
numFound = BufferResizingOverlapBoxNonAlloc(cellCenter, halfCellScale, Quaternion.identity);
}
PerceptionBuf = m_PerceptionBuffer
};
clearHandle = clearJob.Schedule(NumberOfObservations, NumberOfObservations / 12);
}
if (numFound > 0)
if (ShowGizmos)
{
var gizmoJob = new ClearCellActivityJob
ParseColliders(m_ColliderBuffer, numFound, cellIndex, cellCenter);
}
CellActivity = CellActivity,
DebugDefault = DebugDefaultColor
};
clearHandle = JobHandle.CombineDependencies(clearHandle, gizmoJob.Schedule(NumCells, NumCells / 12));
}
return m_PerceptionBuffer;
}
/// <summary>
/// This method attempts to perform the Physics.OverlapBoxNonAlloc and will double the size of the Collider buffer
/// if the number of Colliders in the buffer after the call is equal to the length of the buffer.
/// </summary>
/// <param name="cellCenter"></param>
/// <param name="halfCellScale"></param>
/// <param name="rotation"></param>
/// <returns></returns>
int BufferResizingOverlapBoxNonAlloc(Vector3 cellCenter, Vector3 halfCellScale, Quaternion rotation)
{
int numFound;
// Since we can only get a fixed number of results, requery
// until we're sure we can hold them all (or until we hit the max size).
while (true)
{
numFound = Physics.OverlapBoxNonAlloc(cellCenter, halfCellScale, m_ColliderBuffer, rotation, ObserveMask);
if (numFound == m_ColliderBuffer.Length && m_ColliderBuffer.Length < MaxColliderBufferSize)
var t = transform;
var halfCellScale = new Vector3(CellScaleX / 2f, CellScaleY / 2f, CellScaleZ / 2f);
JobHandle createBoxHandle;
if (RotateToAgent)
m_ColliderBuffer = new Collider[Math.Min(MaxColliderBufferSize, m_ColliderBuffer.Length * 2)];
InitialColliderBufferSize = m_ColliderBuffer.Length;
var createBoxcasts = new CreateBoxcastRotateBatch
{
Commands = m_BoxcastCommands,
halfScale = halfCellScale,
Mat = t.localToWorldMatrix,
CellPoints = CellPoints,
ObserveMask = ObserveMask,
Rotation = t.rotation
};
createBoxHandle = createBoxcasts.Schedule(NumCells, NumCells / 12, clearHandle);
break;
var createBoxcasts = new CreateBoxcastBatch
{
Commands = m_BoxcastCommands,
halfScale = halfCellScale,
Pos = t.position,
CellPoints = CellPoints,
ObserveMask = ObserveMask
};
createBoxHandle = createBoxcasts.Schedule(NumCells, NumCells / 12, clearHandle);
}
m_boxcastJobHandle = BoxcastCommand.ScheduleBatch(m_BoxcastCommands, m_Hits, NumCells / 12, createBoxHandle);
return numFound;
}
Profiler.EndSample();
return m_PerceptionBuffer;
/// <summary>
/// Parses the array of colliders found within a cell. Finds the closest gameobject to the agent root reference within the cell

GameObject closestColliderGo = null;
var minDistanceSquared = float.MaxValue;
var detectableIndex = -1;
if (ReferenceEquals(foundColliders[i], null))
{
continue;
}
var currentColliderGo = foundColliders[i].gameObject;
// Continue if the current collider go is the root reference

var closestColliderPoint = foundColliders[i].ClosestPointOnBounds(cellCenter);
var currentDistanceSquared = (closestColliderPoint - rootReference.transform.position).sqrMagnitude;
Profiler.EndSample();
var index = -1;
for (var ii = 0; ii < DetectableObjects.Length; ii++)
{
if (currentColliderGo.CompareTag(DetectableObjects[ii]))
{
index = ii;
break;
}
}
if (index > -1 && currentDistanceSquared < minDistanceSquared)
if (m_DetectableObjectToIndex.TryGetValue(currentColliderGo.tag, out detectableIndex) && currentDistanceSquared < minDistanceSquared)
{
minDistanceSquared = currentDistanceSquared;
closestColliderGo = currentColliderGo;

if (!ReferenceEquals(closestColliderGo, null))
LoadObjectData(closestColliderGo, cellIndex, (float)Math.Sqrt(minDistanceSquared) * InverseSphereRadius);
LoadObjectData(closestColliderGo, cellIndex, detectableIndex, (float)Math.Sqrt(minDistanceSquared) * InverseSphereRadius);
Profiler.EndSample();
}

/// </example>
protected virtual float[] GetObjectData(GameObject currentColliderGo, float typeIndex, float normalizedDistance)
{
if (m_ChannelBuffer == null)
{
m_ChannelBuffer = new float[ChannelDepth.Length];
}
Array.Clear(m_ChannelBuffer, 0, m_ChannelBuffer.Length);
m_ChannelBuffer[0] = typeIndex;
return m_ChannelBuffer;

/// <param name="currentColliderGo">The gameobject used for better error messages</param>
protected virtual void ValidateValues(float[] channelValues, GameObject currentColliderGo)
{
for (int j = 0; j < channelValues.Length; j++)
for (var j = 0; j < channelValues.Length; j++)
{
if (channelValues[j] < 0)
throw new UnityAgentsException("Expected ChannelValue[" + j + "] for " + currentColliderGo.name + " to be non-negative, was " + channelValues[j]);

/// </summary>
/// <param name="currentColliderGo">The game object that was found colliding with a certain cell</param>
/// <param name="cellIndex">The index of the current cell</param>
/// <param name="detectableIndex">Index into the DetectableObjects array.</param>
/// the distance currentColliderGo is compared to the edge of the gridsensor</param>
protected virtual void LoadObjectData(GameObject currentColliderGo, int cellIndex, float normalizedDistance)
/// the distance currentColliderGo is compared to the edge of the gridsensor</param>
protected virtual void LoadObjectData(GameObject currentColliderGo, int cellIndex, int detectableIndex, float normalizedDistance)
var channelHotVals = new ArraySegment<float>(m_PerceptionBuffer, cellIndex * ObservationPerCell, ObservationPerCell);
for (var i = 0; i < DetectableObjects.Length; i++)
var offset = cellIndex * ObservationPerCell;
for (var ii = 0; ii < ObservationPerCell; ii++)
{
m_PerceptionBuffer[offset + ii] = 0f;
}
// TODO: Create the array already then set the values using "out" in GetObjectData
// Using i+1 as the type index as "0" represents "empty"
var channelValues = GetObjectData(currentColliderGo, (float)detectableIndex + 1, normalizedDistance);
ValidateValues(channelValues, currentColliderGo);
if (ShowGizmos)
for (var ii = 0; ii < channelHotVals.Count; ii++)
var debugRayColor = Color.white;
if (DebugColors.Length > 0)
m_PerceptionBuffer[channelHotVals.Offset + ii] = 0f;
debugRayColor = DebugColors[detectableIndex];
if (!ReferenceEquals(currentColliderGo, null) && currentColliderGo.CompareTag(DetectableObjects[i]))
{
// TODO: Create the array already then set the values using "out" in GetObjectData
// Using i+1 as the type index as "0" represents "empty"
float[] channelValues = GetObjectData(currentColliderGo, (float)i + 1, normalizedDistance);
CellActivity[cellIndex] = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
}
ValidateValues(channelValues, currentColliderGo);
if (ShowGizmos)
switch (gridDepthType)
{
case GridDepthType.Channel:
Color debugRayColor = Color.white;
if (DebugColors.Length > 0)
// The observations are "channel based" so each grid is WxHxC where C is the number of channels
// This typically means that each channel value is normalized between 0 and 1
// If channelDepth is 1, the value is assumed normalized, else the value is normalized by the channelDepth
// The channels are then stored consecutively in PerceptionBuffer.
// NOTE: This is the only grid type that uses floating point values
// For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams:
// channelValues = {2, 1}
// ObservationPerCell = channelValues.Length
// channelValues = {2f/5f, 1f/3f} = {.4, .33..}
// Array.Copy(channelValues, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
for (var j = 0; j < channelValues.Length; j++)
debugRayColor = DebugColors[i];
channelValues[j] /= ChannelDepth[j];
CellActivity[cellIndex] = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
for (var ii = 0; ii < ObservationPerCell; ii++)
{
m_PerceptionBuffer[offset + ii] = channelValues[ii];
}
break;
switch (gridDepthType)
case GridDepthType.ChannelHot:
case GridDepthType.Channel:
// The observations are "channel hot" so each grid is WxHxD where D is the sum of all of the channel depths
// The opposite of the "channel based" case, the channel values are represented as one hot vector per channel and then concatenated together
// Thus channelDepth is assumed to be greater than 1.
// For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams,
// channelValues = {2, 1}
// channelOffsets = {5, 3}
// ObservationPerCell = 5 + 3 = 8
// channelHotVals = {0, 0, 1, 0, 0, 0, 1, 0}
// Array.Copy(channelHotVals, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
for (var j = 0; j < channelValues.Length; j++)
{
if (ChannelDepth[j] > 1)
// The observations are "channel based" so each grid is WxHxC where C is the number of channels
// This typically means that each channel value is normalized between 0 and 1
// If channelDepth is 1, the value is assumed normalized, else the value is normalized by the channelDepth
// The channels are then stored consecutively in PerceptionBuffer.
// NOTE: This is the only grid type that uses floating point values
// For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams:
// channelValues = {2, 1}
// ObservationPerCell = channelValues.Length
// channelValues = {2f/5f, 1f/3f} = {.4, .33..}
// Array.Copy(channelValues, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
for (int j = 0; j < channelValues.Length; j++)
{
channelValues[j] /= ChannelDepth[j];
}
Array.Copy(channelValues, 0, m_PerceptionBuffer, cellIndex * ObservationPerCell, ObservationPerCell);
break;
m_PerceptionBuffer[offset + (int)channelValues[j] + ChannelOffsets[j]] = 1f;
case GridDepthType.ChannelHot:
else
// The observations are "channel hot" so each grid is WxHxD where D is the sum of all of the channel depths
// The opposite of the "channel based" case, the channel values are represented as one hot vector per channel and then concatenated together
// Thus channelDepth is assumed to be greater than 1.
// For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams,
// channelValues = {2, 1}
// channelOffsets = {5, 3}
// ObservationPerCell = 5 + 3 = 8
// channelHotVals = {0, 0, 1, 0, 0, 0, 1, 0}
// Array.Copy(channelHotVals, 0, PerceptionBuffer, cell_id*ObservationPerCell, ObservationPerCell);
for (int j = 0; j < channelValues.Length; j++)
{
if (ChannelDepth[j] > 1)
{
m_PerceptionBuffer[channelHotVals.Offset + (int)channelValues[j] + ChannelOffsets[j]] = 1f;
}
else
{
m_PerceptionBuffer[channelHotVals.Offset + ChannelOffsets[j]] = channelValues[j];
}
}
break;
m_PerceptionBuffer[offset + ChannelOffsets[j]] = channelValues[j];
}
break;
break;
}
Profiler.EndSample();
}

/// <param name="shouldTransformPoint">Bool weather to transform the point to the current transform</param>
protected Vector3 CellToPoint(int cell, bool shouldTransformPoint = true)
{
float x = (cell % GridNumSideZ - OffsetGridNumSide) * CellScaleX;
float z = (cell / GridNumSideZ - OffsetGridNumSide) * CellScaleZ - DiffNumSideZX;
var x = (cell % GridNumSideZ - OffsetGridNumSide) * CellScaleX;
var z = (cell / GridNumSideZ - OffsetGridNumSide) * CellScaleZ - DiffNumSideZX;
if (shouldTransformPoint)
return transform.TransformPoint(new Vector3(x, 0, z));

/// <param name="globalPoint">The 3D point in global space</param>
public int PointToCell(Vector3 globalPoint)
{
Vector3 point = transform.InverseTransformPoint(globalPoint);
var point = transform.InverseTransformPoint(globalPoint);
float x = point.x + HalfOfGridX;
float z = point.z + HalfOfGridZ;
var x = point.x + HalfOfGridX;
var z = point.z + HalfOfGridZ;
int _x = (int)Mathf.Floor(x * PointToCellScalingX);
int _z = (int)Mathf.Floor(z * PointToCellScalingZ);
var _x = (int)Mathf.Floor(x * PointToCellScalingX);
var _z = (int)Mathf.Floor(z * PointToCellScalingZ);
/// <summary>Copies the data from one cell to another</summary>
/// <param name="fromCellID">index of the cell to copy from</param>
/// <param name="toCellID">index of the cell to copy into</param>
protected void CopyCellData(int fromCellID, int toCellID)
{
Array.Copy(m_PerceptionBuffer,
fromCellID * ObservationPerCell,
m_PerceptionBuffer,
toCellID * ObservationPerCell,
ObservationPerCell);
if (ShowGizmos)
CellActivity[toCellID] = CellActivity[fromCellID];
}
// /// <summary>Copies the data from one cell to another</summary>
// /// <param name="fromCellID">index of the cell to copy from</param>
// /// <param name="toCellID">index of the cell to copy into</param>
// protected void CopyCellData(int fromCellID, int toCellID)
// {
// Array.Copy(m_PerceptionBuffer,
// fromCellID * ObservationPerCell,
// m_PerceptionBuffer,
// toCellID * ObservationPerCell,
// ObservationPerCell);
// if (ShowGizmos)
// CellActivity[toCellID] = CellActivity[fromCellID];
// }
void OnDrawGizmos()
{

Start();
Perceive();
UpdateBufferFromJob();
var scale = new Vector3(CellScaleX, 1, CellScaleZ);
var scale = new Vector3(CellScaleX, CellScaleY, CellScaleZ);
var offset = new Vector3(0, GizmoYOffset, 0);
var oldGizmoMatrix = Gizmos.matrix;
for (var i = 0; i < NumCells; i++)

}
Gizmos.matrix = oldGizmoMatrix * cubeTransform;
Gizmos.color = CellActivity[i];
Gizmos.DrawCube(Vector3.zero, Vector3.one);
Gizmos.DrawWireCube(Vector3.zero, Vector3.one);
}
Gizmos.matrix = oldGizmoMatrix;

/// <inheritdoc/>
void ISensor.Update()
{
using (TimerStack.Instance.Scoped("GridSensor.Update"))
Profiler.BeginSample("GridSensor.Update");
Profiler.EndSample();
}
/// <summary>Gets the observation shape</summary>

/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
using (TimerStack.Instance.Scoped("GridSensor.WriteToTensor"))
UpdateBufferFromJob();
var index = 0;
Profiler.BeginSample("GridSensor.WriteToTensor");
int index = 0;
for (var h = GridNumSideZ - 1; h >= 0; h--) // height
{
for (var w = 0; w < GridNumSideX; w++) // width

}
}
}
return index;
Profiler.EndSample();
return index;
}
internal void UpdateBufferFromJob()
{
Profiler.BeginSample("UpdateBufferFromJob");
{
m_boxcastJobHandle.Complete();
for (var cellIndex = 0; cellIndex < NumCells; cellIndex++)
{
var c = m_Hits[cellIndex].collider;
if (ReferenceEquals(c, null))
{
continue;
}
for (var i = 0; i < DetectableObjects.Length; i++)
{
if (c.CompareTag(DetectableObjects[i]))
{
LoadObjectData(c.gameObject, cellIndex, i, 0);
break;
}
}
}
}
Profiler.EndSample();
}
}
}

3
com.unity.ml-agents.extensions/Runtime/Unity.ML-Agents.Extensions.asmdef


"references": [
"Unity.Barracuda",
"Unity.ML-Agents",
"Unity.ML-Agents.Extensions.Input"
"Unity.ML-Agents.Extensions.Input",
"Unity.Burst"
]
}

37
com.unity.ml-agents.extensions/Tests/Editor/Sensors/CountingGridSensorPerceiveTests.cs


using System.Collections;
using NUnit.Framework;
using Unity.Collections;
using UnityEngine;
using UnityEngine.TestTools;
using Unity.MLAgents.Extensions.Sensors;

public GameObject CreateBlock(Vector3 postion, string tag, string name)
{
GameObject boxGo = new GameObject(name);
var boxGo = new GameObject(name);
boxGo.tag = tag;
boxGo.transform.position = postion;
boxGo.AddComponent<BoxCollider>();

yield return null;
float[] output = gridSensor.Perceive();
var output = gridSensor.Perceive(); gridSensor.UpdateBufferFromJob();
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 };
float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { 1f }, 4);
float[] expectedDefault = new float[] { 0 };
var subarrayIndicies = new int[] { 77, 78, 87, 88 };
var expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { 1f }, 4);
var expectedDefault = new float[] { 0 };
GridObsTestUtils.AssertSubarraysAtIndex(output, subarrayIndicies, expectedSubarrays, expectedDefault);
}

yield return null;
float[] output = gridSensor.Perceive();
var output = gridSensor.Perceive(); gridSensor.UpdateBufferFromJob();
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 };
float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { 1f }, 4);
float[] expectedDefault = new float[] { 0f };
var subarrayIndicies = new int[] { 77, 78, 87, 88 };
var expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { 1f }, 4);
var expectedDefault = new float[] { 0f };
GridObsTestUtils.AssertSubarraysAtIndex(output, subarrayIndicies, expectedSubarrays, expectedDefault);
}

yield return null;
float[] output = gridSensor.Perceive();
var output = gridSensor.Perceive();
gridSensor.UpdateBufferFromJob();
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 };
float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { .5f }, 4);
float[] expectedDefault = new float[] { 0 };
var subarrayIndicies = new int[] { 77, 78, 87, 88 };
var expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { .25f }, 4);
var expectedDefault = new float[] { 0 };
GridObsTestUtils.AssertSubarraysAtIndex(output, subarrayIndicies, expectedSubarrays, expectedDefault);
}

yield return null;
float[] output = gridSensor.Perceive();
var output = gridSensor.Perceive();
gridSensor.UpdateBufferFromJob();
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 };
float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { .5f, 1 }, 4);
float[] expectedDefault = new float[] { 0, 0 };
var subarrayIndicies = new int[] { 77, 78, 87, 88 };
var expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { 0f, 1 }, 4);
var expectedDefault = new float[] { 0, 0 };
GridObsTestUtils.AssertSubarraysAtIndex(output, subarrayIndicies, expectedSubarrays, expectedDefault);
}
}

25
com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridObservationPerceiveTests.cs


using System.Collections;
using NUnit.Framework;
using Unity.Collections;
using UnityEngine;
using UnityEngine.TestTools;
using Unity.MLAgents.Extensions.Sensors;

yield return null;
float[] output = gridSensor.Perceive();
var output = gridSensor.Perceive(); gridSensor.UpdateBufferFromJob();
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 };
float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new float[] { 0, 0, 1 }, 4);
float[] expectedDefault = new float[] { 1, 0, 0 };
var subarrayIndicies = new int[] { 77, 78, 87, 88 };
var expectedSubarrays = GridObsTestUtils.DuplicateArray(new float[] { 0, 0, 1 }, 4);
var expectedDefault = new float[] { 1, 0, 0 };
GridObsTestUtils.AssertSubarraysAtIndex(output, subarrayIndicies, expectedSubarrays, expectedDefault);
}

yield return null;
float[] output = gridSensor.Perceive();
var output = gridSensor.Perceive(); gridSensor.UpdateBufferFromJob();
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 };
float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { 2f / 3f }, 4);
float[] expectedDefault = new float[] { 0f };
var subarrayIndicies = new int[] { 77, 78, 87, 88 };
var expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { 2f / 3f }, 4);
var expectedDefault = new float[] { 0f };
GridObsTestUtils.AssertSubarraysAtIndex(output, subarrayIndicies, expectedSubarrays, expectedDefault);
}

yield return null;
float[] output = gridSensor.Perceive();
var output = gridSensor.Perceive(); gridSensor.UpdateBufferFromJob();
int[] subarrayIndicies = new int[] { 77, 78, 87, 88 };
float[][] expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { 1f / 3f }, 4);
float[] expectedDefault = new float[] { 0f };
var subarrayIndicies = new int[] { 77, 78, 87, 88 };
var expectedSubarrays = GridObsTestUtils.DuplicateArray(new[] { 1f / 3f }, 4);
var expectedDefault = new float[] { 0f };
GridObsTestUtils.AssertSubarraysAtIndex(output, subarrayIndicies, expectedSubarrays, expectedDefault);
}
}

3
com.unity.ml-agents.extensions/Tests/Editor/Sensors/GridSensorTestUtils.cs


using NUnit.Framework;
using System;
using System.Linq;
using Unity.Collections;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{

/// expecedDefaultArray = new float[] {0, 0}
/// )
/// </example>
public static void AssertSubarraysAtIndex(float[] total, int[] indicies, float[][] expectedArrays, float[] expectedDefaultArray)
public static void AssertSubarraysAtIndex(NativeArray<float> total, int[] indicies, float[][] expectedArrays, float[] expectedDefaultArray)
{
int totalIndex = 0;
int subIndex = 0;

3
com.unity.ml-agents.extensions/package.json


"unity": "2018.4",
"description": "A source-only package for new features based on ML-Agents",
"dependencies": {
"com.unity.ml-agents": "1.9.0-preview"
"com.unity.ml-agents": "1.9.0-preview",
"com.unity.burst": "1.4.6"
}
}
正在加载...
取消
保存