CellActivity;
public Color DebugDefault;
public void Execute(int index)
{
CellActivity[index] = DebugDefault;
}
}
/// <summary>
/// Clears the perception buffer before new data is loaded into it.
/// When <c>gridDepthType</c> is <c>GridDepthType.ChannelHot</c>, every cell is reset to the channel-hot
/// default values; otherwise every entry is reset to zero.
/// When <c>ShowGizmos</c> is enabled, the cell activity array used for debug drawing is also reset to
/// <c>DebugDefaultColor</c>.
/// </summary>
public void ClearPerceptionBuffer()
{
    Profiler.BeginSample("ClearPerceptionBuffer");

    var useChannelHotDefaults = gridDepthType == GridDepthType.ChannelHot;
    for (var idx = 0; idx < m_PerceptionBuffer.Length; idx++)
    {
        // Only the first ObservationPerCell default entries are used; they are repeated for every cell.
        m_PerceptionBuffer[idx] = useChannelHotDefaults
            ? m_ChannelHotDefaultPerceptionBuffer[idx % ObservationPerCell]
            : 0f;
    }

    if (ShowGizmos)
    {
        for (var cell = 0; cell < NumCells; cell++)
        {
            CellActivity[cell] = DebugDefaultColor;
        }
    }

    Profiler.EndSample();
}
void OnDisable()
{
    // Nothing is done when the component is disabled; the native arrays are released in OnDestroy.
}
/// <summary>
/// Releases the native arrays owned by the sensor.
/// The job chain scheduled by <see cref="Perceive"/> is completed first, because it reads and writes
/// these arrays and disposing a NativeArray that a scheduled job still uses is invalid.
/// </summary>
void OnDestroy()
{
    // Perceive() schedules the clear/boxcast jobs without waiting for them; make sure they are finished
    // before freeing the buffers they use.
    m_boxcastJobHandle.Complete();

    DisposeNativeArray(m_Hits);
    DisposeNativeArray(m_HitIndexes);
    DisposeNativeArray(m_BoxcastCommands);
    DisposeNativeArray(CellActivity);
    DisposeNativeArray(CellPoints);
    DisposeNativeArray(m_ChannelHotDefaultPerceptionBuffer);
    DisposeNativeArray(ChannelOffsets);
    // NOTE(review): m_PerceptionBuffer is not disposed here; confirm where it is released.
}
/// <summary>
/// Disposes <paramref name="array"/> if it has been allocated; arrays that were never created are ignored.
/// </summary>
/// <typeparam name="TD">Element type of the native array.</typeparam>
/// <param name="array">The native array to release.</param>
static void DisposeNativeArray<TD>(NativeArray<TD> array)
    where TD : struct
{
    if (array.IsCreated)
    {
        array.Dispose();
    }
}
/// <summary>
/// Gets the shape of the grid observation.
/// </summary>
/// <returns>
/// A new integer array <c>{ GridNumSideX, GridNumSideZ, ObservationPerCell }</c>; the same array is
/// also stored in <c>m_Shape</c>.
/// </returns>
public int[] GetFloatObservationShape()
{
    return m_Shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
}
/// <inheritdoc/>
public string GetName() => Name;
/// <inheritdoc/>
public virtual SensorCompressionType GetCompressionType() => CompressionType;
/// <inheritdoc/>
public BuiltInSensorType GetBuiltInSensorType() => BuiltInSensorType.GridSensor;
/// <summary>
/// GetCompressedObservation - completes any pending perception work via <see cref="UpdateBufferFromJob"/>,
/// then packs the perception buffer three channels at a time into <c>m_perceptionTexture2D</c>,
/// encodes each texture as PNG and concatenates the encodings.
/// </summary>
/// <returns>byte[] containing the concatenated PNG encodings of the grid observation.</returns>
public byte[] GetCompressedObservation()
{
    UpdateBufferFromJob();
    var allBytes = new List<byte>();
    Profiler.BeginSample("GridSensor.GetCompressedObservation");
    for (var i = 0; i < NumImages - 1; i++)
    {
        // Every image except the last carries a full set of three channels.
        ChannelsToTexture(3 * i, 3);
        allBytes.AddRange(m_perceptionTexture2D.EncodeToPNG());
    }

    // The last image carries the remaining NumChannelsOnLastImage channels.
    ChannelsToTexture(3 * (NumImages - 1), NumChannelsOnLastImage);
    allBytes.AddRange(m_perceptionTexture2D.EncodeToPNG());
    Profiler.EndSample();
    return allBytes.ToArray();
}
/// <summary>
/// ChannelsToTexture - copies <paramref name="numChannelsToAdd"/> consecutive channels, starting at
/// <paramref name="channelIndex"/>, from every cell of the perception buffer into the matching color
/// components of <c>m_PerceptionColors</c>. It then uploads the colors to <c>m_perceptionTexture2D</c> with SetPixels,
/// so the texture can be encoded as an image holding those channels.
/// Color components at or beyond <paramref name="numChannelsToAdd"/> are left unchanged.
/// </summary>
/// <param name="channelIndex">Index (within a cell) of the first channel to copy.</param>
/// <param name="numChannelsToAdd">Number of consecutive channels to copy; channel k goes to color component k.</param>
protected void ChannelsToTexture(int channelIndex, int numChannelsToAdd)
{
    for (var cell = 0; cell < NumCells; cell++)
    {
        var source = cell * ObservationPerCell + channelIndex;
        for (var component = 0; component < numChannelsToAdd; component++)
        {
            m_PerceptionColors[cell][component] = m_PerceptionBuffer[source + component];
        }
    }
    m_perceptionTexture2D.SetPixels(m_PerceptionColors);
}
/// <summary>
/// Job that builds one axis-aligned <see cref="BoxcastCommand"/> per grid cell.
/// Each box starts 2 units above the cell centre (<see cref="Pos"/> plus the cell offset) and is cast
/// 10 units straight down.
/// </summary>
[BurstCompile(CompileSynchronously = true)]
struct CreateBoxcastBatch : IJobParallelFor
{
    /// <summary>Output: one boxcast command per cell.</summary>
    public NativeArray<BoxcastCommand> Commands;

    /// <summary>Half extents of the cast box.</summary>
    public Vector3 halfScale;

    /// <summary>Layer mask passed to every boxcast.</summary>
    public int ObserveMask;

    /// <summary>Cell centre offsets that are added to <see cref="Pos"/>.</summary>
    [ReadOnly]
    public NativeArray<Vector3> CellPoints;

    /// <summary>Position the cell offsets are relative to (the sensor's transform position).</summary>
    public Vector3 Pos;

    public void Execute(int index)
    {
        var cellCenter = Pos + CellPoints[index];
        var origin = new Vector3(cellCenter.x, cellCenter.y + 2.0f, cellCenter.z);
        Commands[index] = new BoxcastCommand(origin,
            halfScale,
            Quaternion.identity,
            Vector3.down,
            10.0f,
            ObserveMask);
    }
}
/// <summary>
/// Job that builds one <see cref="BoxcastCommand"/> per grid cell for a grid that rotates with the agent.
/// Each cell point is transformed by <see cref="Mat"/>; the box, oriented by <see cref="Rotation"/>, starts
/// 5 units above the transformed point and is cast 10 units straight down.
/// </summary>
[BurstCompile(CompileSynchronously = true)]
struct CreateBoxcastRotateBatch : IJobParallelFor
{
    /// <summary>Output: one boxcast command per cell.</summary>
    public NativeArray<BoxcastCommand> Commands;

    /// <summary>Half extents of the cast box.</summary>
    public Vector3 halfScale;

    /// <summary>Layer mask passed to every boxcast.</summary>
    public int ObserveMask;

    /// <summary>Cell centre points in the sensor's local space.</summary>
    [ReadOnly]
    public NativeArray<Vector3> CellPoints;

    /// <summary>Matrix applied to every cell point (the sensor's localToWorldMatrix).</summary>
    [ReadOnly]
    public Matrix4x4 Mat;

    /// <summary>Orientation of every cast box (the sensor's rotation).</summary>
    public Quaternion Rotation;

    public void Execute(int index)
    {
        var cellCenter = Mat.MultiplyPoint(CellPoints[index]);
        // NOTE(review): this origin is 5 units above the cell while CreateBoxcastBatch uses 2 units,
        // with the same 10-unit cast distance; confirm the difference is intended.
        var origin = new Vector3(cellCenter.x, cellCenter.y + 5.0f, cellCenter.z);
        Commands[index] = new BoxcastCommand(origin,
            halfScale,
            Rotation,
            Vector3.down,
            10.0f,
            ObserveMask);
    }
}
// ReSharper disable Unity.PerformanceAnalysis
/// <summary>
/// Perceive - schedules this step's perception work without waiting for it to finish.
/// It first completes the previous step's jobs, then schedules:
/// (1) clearing the perception buffer, plus the cell activity colors when <c>ShowGizmos</c> is enabled;
/// (2) building one boxcast command per cell;
/// (3) running the boxcasts into <c>m_Hits</c>.
/// The final handle is stored in <c>m_boxcastJobHandle</c>. <see cref="UpdateBufferFromJob"/> completes it and
/// loads the hit objects into the buffer.
/// </summary>
/// <returns>
/// The perception buffer. Its contents are only complete after <see cref="UpdateBufferFromJob"/> has run.
/// If the native buffers have not been allocated, a default (not created) array is returned.
/// </returns>
public NativeArray<float> Perceive()
{
    if (!m_PerceptionBuffer.IsCreated || !m_Hits.IsCreated || !m_BoxcastCommands.IsCreated)
    {
        return new NativeArray<float>();
    }

    Profiler.BeginSample("GridSensor.Perceive");

    // The previous step's jobs still use the buffers that are about to be rescheduled.
    m_boxcastJobHandle.Complete();

    // count / 12 is 0 for grids with fewer than 12 cells/observations; never schedule with a batch size of 0.
    var observationBatch = Mathf.Max(1, NumberOfObservations / 12);
    var cellBatch = Mathf.Max(1, NumCells / 12);

    JobHandle clearHandle;
    if (gridDepthType == GridDepthType.ChannelHot)
    {
        var clearJob = new ClearBufferOneHotJob
        {
            PerceptionBuf = m_PerceptionBuffer,
            ChannelHotBuf = m_ChannelHotDefaultPerceptionBuffer,
            ObservationPerCell = ObservationPerCell
        };
        clearHandle = clearJob.Schedule(NumberOfObservations, observationBatch);
    }
    else
    {
        var clearJob = new ClearBufferChannelJob
        {
            PerceptionBuf = m_PerceptionBuffer
        };
        clearHandle = clearJob.Schedule(NumberOfObservations, observationBatch);
    }

    if (ShowGizmos)
    {
        var gizmoJob = new ClearCellActivityJob
        {
            CellActivity = CellActivity,
            DebugDefault = DebugDefaultColor
        };
        var gizmoHandle = gizmoJob.Schedule(NumCells, cellBatch);
        clearHandle = JobHandle.CombineDependencies(clearHandle, gizmoHandle);
    }

    var t = transform;
    var halfCellScale = new Vector3(CellScaleX / 2f, CellScaleY / 2f, CellScaleZ / 2f);

    // The boxcast jobs are chained after the clear jobs, so completing m_boxcastJobHandle completes them all.
    JobHandle createBoxHandle;
    if (RotateToAgent)
    {
        var createBoxcasts = new CreateBoxcastRotateBatch
        {
            Commands = m_BoxcastCommands,
            halfScale = halfCellScale,
            Mat = t.localToWorldMatrix,
            CellPoints = CellPoints,
            ObserveMask = ObserveMask,
            Rotation = t.rotation
        };
        createBoxHandle = createBoxcasts.Schedule(NumCells, cellBatch, clearHandle);
    }
    else
    {
        var createBoxcasts = new CreateBoxcastBatch
        {
            Commands = m_BoxcastCommands,
            halfScale = halfCellScale,
            Pos = t.position,
            CellPoints = CellPoints,
            ObserveMask = ObserveMask
        };
        createBoxHandle = createBoxcasts.Schedule(NumCells, cellBatch, clearHandle);
    }

    m_boxcastJobHandle = BoxcastCommand.ScheduleBatch(m_BoxcastCommands, m_Hits, cellBatch, createBoxHandle);

    Profiler.EndSample();
    return m_PerceptionBuffer;
}
/// <summary>
/// Parses the array of colliders found within a cell. Intended to find the closest detectable GameObject
/// to the agent root reference within the cell and load its data.
/// </summary>
/// <remarks>
/// NOTE(review): the whole body is commented out, so this method currently does nothing.
/// In this version, boxcast hits are loaded into the buffer by UpdateBufferFromJob instead.
/// </remarks>
/// <param name="foundColliders">Array of the colliders found within the cell</param>
/// <param name="numFound">Number of colliders found.</param>
/// <param name="cellIndex">The index of the cell</param>
/// <param name="cellCenter">The center position of the cell</param>
protected virtual void ParseColliders(Collider[] foundColliders, int numFound, int cellIndex, Vector3 cellCenter)
{
// NOTE(review): if this legacy code is ever restored, remove the Profiler.EndSample() inside the loop.
// Together with the one after the loop, it would end more samples than were begun.
// The m_DetectableObjectToIndex check below also has an empty body.
// Profiler.BeginSample("GridSensor.ParseColliders");
// GameObject closestColliderGo = null;
// var minDistanceSquared = float.MaxValue;
//
// var detectableIndex = -1;
// for (var i = 0; i < numFound; i++)
// {
// if (ReferenceEquals(foundColliders[i], null))
// {
// continue;
// }
// var currentColliderGo = foundColliders[i].gameObject;
//
// // Continue if the current collider go is the root reference
// if (ReferenceEquals(currentColliderGo, rootReference))
// continue;
//
// var closestColliderPoint = foundColliders[i].ClosestPointOnBounds(cellCenter);
// var currentDistanceSquared = (closestColliderPoint - rootReference.transform.position).sqrMagnitude;
// Profiler.EndSample();
//
// // Checks if our colliders contain a detectable object
// if (m_DetectableObjectToIndex.TryGetValue(currentColliderGo.tag, out detectableIndex) && currentDistanceSquared < minDistanceSquared)
// {
// }
// for (var ii = 0; ii < DetectableObjects.Length; ii++)
// {
// if (currentColliderGo.CompareTag(DetectableObjects[ii]))
// {
// minDistanceSquared = currentDistanceSquared;
// closestColliderGo = currentColliderGo;
// break;
// }
// }
// }
//
// if (!ReferenceEquals(closestColliderGo, null))
// LoadObjectData(closestColliderGo, cellIndex, detectableIndex, (float)Math.Sqrt(minDistanceSquared) * InverseSphereRadius);
// Profiler.EndSample();
}
/// <summary>
/// GetObjectData - returns an array of values that represent the game object.
/// This is one of the few methods that one may need to override to get their required functionality.
/// For instance, if one wants specific information about the current gameobject, they can use this method
/// to extract it and then return it in an array format.
/// </summary>
/// <remarks>
/// The default implementation zeroes the shared <c>m_ChannelBuffer</c>, stores <paramref name="typeIndex"/> in
/// channel 0 and returns that same buffer, so the result is overwritten by the next call.
/// With the default ValidateValues, every returned value must be non-negative and not greater than the
/// matching ChannelDepth entry.
/// </remarks>
/// <returns>
/// A float[] containing the data that holds the representative information of the passed in gameObject.
/// </returns>
/// <param name="currentColliderGo">The game object that was found colliding with a certain cell.</param>
/// <param name="typeIndex">The index of the type (tag) of the gameObject
/// (e.g., if this GameObject had the 3rd tag out of 4, typeIndex would be 2.0f).</param>
/// <param name="normalizedDistance">A float between 0 and 1 describing the ratio of
/// the distance currentColliderGo is compared to the edge of the gridsensor.</param>
/// <example>
/// Here is an example of extending GetObjectData to include information about a potential Rigidbody:
/// <code>
/// protected override float[] GetObjectData(GameObject currentColliderGo,
///     float type_index, float normalized_distance)
/// {
///     float[] channelValues = new float[ChannelDepth.Length]; // ChannelDepth.Length = 4 in this example
///     channelValues[0] = type_index;
///     Rigidbody goRb = currentColliderGo.GetComponent&lt;Rigidbody&gt;();
///     if (goRb != null)
///     {
///         channelValues[1] = goRb.velocity.x;
///         channelValues[2] = goRb.velocity.y;
///         channelValues[3] = goRb.velocity.z;
///     }
///     return channelValues;
/// }
/// </code>
/// </example>
protected virtual float[] GetObjectData(GameObject currentColliderGo, float typeIndex, float normalizedDistance)
{
    var channels = m_ChannelBuffer;
    Array.Clear(channels, 0, channels.Length);
    channels[0] = typeIndex;
    return channels;
}
/// <summary>
/// Runs basic validation checks to make sure the channel values can be normalized / one-hot encoded.
/// </summary>
/// <param name="channelValues">The values to be validated.</param>
/// <param name="currentColliderGo">The gameobject used for better error messages.</param>
/// <exception cref="UnityAgentsException">
/// Thrown when a value is negative or greater than its ChannelDepth. For channel-hot channels with a depth
/// greater than 1, a value equal to its ChannelDepth is also rejected, because that value has no slot in the
/// one-hot encoding.
/// </exception>
protected virtual void ValidateValues(float[] channelValues, GameObject currentColliderGo)
{
    for (var j = 0; j < channelValues.Length; j++)
    {
        var value = channelValues[j];
        var depth = ChannelDepth[j];
        if (value < 0)
        {
            throw new UnityAgentsException("Expected ChannelValue[" + j + "] for " + currentColliderGo.name + " to be non-negative, was " + value);
        }

        // In channel-hot mode, a channel with depth > 1 is written at index (int)value inside a block of `depth`
        // slots. The value must therefore be strictly below the depth, or the write lands outside that channel's block.
        // Channel mode divides by the depth instead, so value == depth (i.e. 1.0) remains valid there.
        var isOneHotChannel = gridDepthType == GridDepthType.ChannelHot && depth > 1;
        var outOfRange = isOneHotChannel ? value >= depth : value > depth;
        if (outOfRange)
        {
            throw new UnityAgentsException("Expected ChannelValue[" + j + "] for " + currentColliderGo.name + " to be less than ChannelDepth[" + j + "] (" + depth + "), was " + value);
        }
    }
}
/// <summary>
/// LoadObjectData - extracts the data of a detected GameObject with <see cref="GetObjectData"/> and writes it
/// into cell <paramref name="cellIndex"/> of the perception buffer, encoded according to <c>gridDepthType</c>.
/// The cell is zeroed first, and the values are checked with <see cref="ValidateValues"/> before encoding.
/// See the comments in each switch case below for the encodings.
/// </summary>
/// <param name="currentColliderGo">The game object that was found colliding with a certain cell.</param>
/// <param name="cellIndex">The index of the current cell.</param>
/// <param name="detectableIndex">Index into the DetectableObjects array.</param>
/// <param name="normalizedDistance">A float between 0 and 1 describing the ratio of
/// the distance currentColliderGo is compared to the edge of the gridsensor.</param>
protected virtual void LoadObjectData(GameObject currentColliderGo, int cellIndex, int detectableIndex, float normalizedDistance)
{
    Profiler.BeginSample("GridSensor.LoadObjectData");
    var offset = cellIndex * ObservationPerCell;
    for (var ii = 0; ii < ObservationPerCell; ii++)
    {
        m_PerceptionBuffer[offset + ii] = 0f;
    }

    // TODO: Create the array already then set the values using "out" in GetObjectData
    // Using detectableIndex + 1 as the type index as "0" represents "empty"
    var channelValues = GetObjectData(currentColliderGo, (float)detectableIndex + 1, normalizedDistance);
    ValidateValues(channelValues, currentColliderGo);

    if (ShowGizmos)
    {
        // Fall back to white when DebugColors has no entry for this detectable type, rather than
        // indexing past the end of a DebugColors array shorter than DetectableObjects.
        var debugRayColor = detectableIndex >= 0 && detectableIndex < DebugColors.Length
            ? DebugColors[detectableIndex]
            : Color.white;
        CellActivity[cellIndex] = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
    }

    switch (gridDepthType)
    {
        case GridDepthType.Channel:
        {
            // The observations are "channel based" so each grid is WxHxC where C is the number of channels.
            // Each channel value is divided by its channelDepth (a depth of 1 leaves the value unchanged,
            // i.e. it is assumed to be normalized already). The channels are then stored consecutively in the buffer.
            // NOTE: This is the only grid type that uses floating point values.
            // For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams:
            // channelValues = {2, 1}
            // ObservationPerCell = channelValues.Length
            // channelValues = {2f/5f, 1f/3f} = {.4, .33..}
            for (var j = 0; j < channelValues.Length; j++)
            {
                channelValues[j] /= ChannelDepth[j];
            }

            for (var ii = 0; ii < ObservationPerCell; ii++)
            {
                m_PerceptionBuffer[offset + ii] = channelValues[ii];
            }
            break;
        }
        case GridDepthType.ChannelHot:
        {
            // The observations are "channel hot" so each grid is WxHxD where D is the sum of all of the channel depths.
            // Unlike the "channel based" case, each channel with depth > 1 is represented as a one hot vector,
            // and the vectors are concatenated. A channel with depth 1 is copied as-is.
            // For example, if a cell contains the 3rd type of 5 possible on the 2nd team of 3 possible teams:
            // channelValues = {2, 1}
            // ChannelDepth = {5, 3}, ChannelOffsets = {0, 5}
            // ObservationPerCell = 5 + 3 = 8
            // channelHotVals = {0, 0, 1, 0, 0, 0, 1, 0}
            for (var j = 0; j < channelValues.Length; j++)
            {
                if (ChannelDepth[j] > 1)
                {
                    m_PerceptionBuffer[offset + (int)channelValues[j] + ChannelOffsets[j]] = 1f;
                }
                else
                {
                    m_PerceptionBuffer[offset + ChannelOffsets[j]] = channelValues[j];
                }
            }
            break;
        }
    }
    Profiler.EndSample();
}
/// <summary>
/// Converts the index of a cell to the 3D position of the cell's center (y is zero).
/// </summary>
/// <param name="cell">The index of the cell.</param>
/// <param name="shouldTransformPoint">Whether to transform the local point with <c>transform.TransformPoint</c>.</param>
/// <returns>Vector3 of the position of the center of the cell.</returns>
protected Vector3 CellToPoint(int cell, bool shouldTransformPoint = true)
{
    // NOTE(review): this splits the index using GridNumSideZ as the row length, while PointToCell and Write
    // use GridNumSideX. The two layouts only agree for square grids; confirm the intended layout.
    var column = cell % GridNumSideZ;
    var row = cell / GridNumSideZ;
    var localPoint = new Vector3(
        (column - OffsetGridNumSide) * CellScaleX,
        0,
        (row - OffsetGridNumSide) * CellScaleZ - DiffNumSideZX);
    return shouldTransformPoint ? transform.TransformPoint(localPoint) : localPoint;
}
/// <summary>
/// Finds the cell in which the given global point falls.
/// </summary>
/// <param name="globalPoint">The 3D point in global space.</param>
/// <returns>
/// The index of the cell in which the global point falls, or -1 if the point does not fall into a cell.
/// </returns>
public int PointToCell(Vector3 globalPoint)
{
    var point = transform.InverseTransformPoint(globalPoint);
    if (point.x < -HalfOfGridX || point.x > HalfOfGridX || point.z < -HalfOfGridZ || point.z > HalfOfGridZ)
    {
        return -1;
    }

    // Shift into [0, 2 * HalfOfGrid] so the origin is the grid corner.
    var x = point.x + HalfOfGridX;
    var z = point.z + HalfOfGridZ;

    // The bounds check above is inclusive. A point exactly on the far edge can therefore floor to
    // GridNumSideX / GridNumSideZ, one past the last column/row; this assumes the scaling factors map the full
    // grid extent onto the cell count. Clamp such points into the last cell instead of returning a
    // wrapped or out-of-range index.
    var column = Mathf.Min((int)Mathf.Floor(x * PointToCellScalingX), GridNumSideX - 1);
    var row = Mathf.Min((int)Mathf.Floor(z * PointToCellScalingZ), GridNumSideZ - 1);
    return GridNumSideX * row + column;
}
// /// Copies the data from one cell to another
// /// index of the cell to copy from
// /// index of the cell to copy into
// protected void CopyCellData(int fromCellID, int toCellID)
// {
// Array.Copy(m_PerceptionBuffer,
// fromCellID * ObservationPerCell,
// m_PerceptionBuffer,
// toCellID * ObservationPerCell,
// ObservationPerCell);
// if (ShowGizmos)
// CellActivity[toCellID] = CellActivity[fromCellID];
// }
/// <summary>
/// When <c>ShowGizmos</c> is enabled, runs perception (completing its jobs) and draws one cube per cell,
/// colored with that cell's <c>CellActivity</c> entry and raised 5 units above the cell.
/// </summary>
void OnDrawGizmos()
{
    if (!ShowGizmos)
    {
        return;
    }

    var isEditModeInEditor = Application.isEditor && !Application.isPlaying;
    // NOTE(review): Start() runs on every gizmo draw in edit mode; confirm it does not allocate
    // native arrays that are never disposed.
    if (isEditModeInEditor)
    {
        Start();
    }

    Perceive();
    UpdateBufferFromJob();

    var cellSize = new Vector3(CellScaleX, CellScaleY, CellScaleZ);
    var verticalOffset = new Vector3(0, 5.0f, 0);
    var baseMatrix = Gizmos.matrix;
    for (var cell = 0; cell < NumCells; cell++)
    {
        var cubeTransform = RotateToAgent
            ? Matrix4x4.TRS(CellToPoint(cell) + verticalOffset, transform.rotation, cellSize)
            : Matrix4x4.TRS(CellToPoint(cell, false) + transform.position + verticalOffset, Quaternion.identity, cellSize);
        Gizmos.matrix = baseMatrix * cubeTransform;
        Gizmos.color = CellActivity[cell];
        Gizmos.DrawCube(Vector3.zero, Vector3.one);
    }
    Gizmos.matrix = baseMatrix;

    if (isEditModeInEditor)
    {
        DestroyImmediate(m_perceptionTexture2D);
    }
}
/// <inheritdoc/>
void ISensor.Update()
{
    // Only schedules the perception jobs; their results are collected later by UpdateBufferFromJob.
    Profiler.BeginSample("GridSensor.Update");
    Perceive();
    Profiler.EndSample();
}
/// <summary>
/// Gets the observation shape.
/// </summary>
/// <returns>
/// int[] of the observation shape, <c>{ GridNumSideX, GridNumSideZ, ObservationPerCell }</c>; the same array
/// is also stored in <c>m_Shape</c>.
/// </returns>
public override int[] GetObservationShape()
{
    var shape = new[] { GridNumSideX, GridNumSideZ, ObservationPerCell };
    m_Shape = shape;
    return shape;
}
/// <inheritdoc/>
/// <remarks>
/// Completes pending perception work first, then copies the perception buffer in order.
/// Writer rows are visited from <c>GridNumSideZ - 1</c> down to 0, so the first buffer row ends up in the
/// last writer row. Returns the number of values written.
/// </remarks>
public int Write(ObservationWriter writer)
{
    UpdateBufferFromJob();
    var written = 0;
    Profiler.BeginSample("GridSensor.WriteToTensor");
    for (var row = GridNumSideZ - 1; row >= 0; row--)
    {
        for (var col = 0; col < GridNumSideX; col++)
        {
            for (var channel = 0; channel < ObservationPerCell; channel++)
            {
                writer[row, col, channel] = m_PerceptionBuffer[written++];
            }
        }
    }
    Profiler.EndSample();
    return written;
}
/// <summary>
/// Waits for the jobs scheduled by <see cref="Perceive"/>. Then, for every cell whose boxcast hit a collider
/// tagged with one of the DetectableObjects, loads that object's data into the perception buffer.
/// Only the first matching tag is used. If the hit or perception buffers are not allocated, nothing is loaded.
/// </summary>
internal void UpdateBufferFromJob()
{
    Profiler.BeginSample("UpdateBufferFromJob");
    m_boxcastJobHandle.Complete();

    // Same precondition as Perceive(): without these buffers there is nothing to read from or write to.
    if (m_Hits.IsCreated && m_PerceptionBuffer.IsCreated)
    {
        for (var cellIndex = 0; cellIndex < NumCells; cellIndex++)
        {
            var c = m_Hits[cellIndex].collider;
            if (ReferenceEquals(c, null))
            {
                // The boxcast for this cell hit nothing.
                continue;
            }

            for (var i = 0; i < DetectableObjects.Length; i++)
            {
                if (c.CompareTag(DetectableObjects[i]))
                {
                    LoadObjectData(c.gameObject, cellIndex, i, 0);
                    break;
                }
            }
        }
    }

    Profiler.EndSample();
}
}
}
|