using System;
using Unity.Burst;
using Unity.Collections;
using Unity.Collections.LowLevel.Unsafe;
using Unity.Jobs;
using Unity.Mathematics;
using Unity.Profiling;
namespace UnityEngine.Perception.GroundTruth
{
    /// <summary>
    /// A CPU-based pass which computes bounding box and pixel counts per-object from instance segmentation images.
    /// </summary>
    public class RenderedObjectInfoGenerator : IGroundTruthGenerator, IDisposable
    {
        static ProfilerMarker s_LabelJobs = new ProfilerMarker("Label Jobs");
        static ProfilerMarker s_LabelMerge = new ProfilerMarker("Label Merge");

        const int k_StartingObjectCount = 1 << 8;

        // Value stored in m_InstanceIdToClassIdLookup for instance ids that have no label. Zero cannot serve
        // this purpose because it is a valid label index.
        const int k_UnlabeledClassId = -1;

        /// <summary>
        /// A horizontal run of pixels in one image row that all share the same instance id.
        /// </summary>
        struct Object1DSpan
        {
            public int instanceId;
            // Row index within the full image.
            public int row;
            // Inclusive column index of the first pixel of the run.
            public int left;
            // Inclusive column index of the last pixel of the run.
            public int right;
        }

        /// <summary>
        /// Scans `rows` consecutive image rows (each `width` values long) and appends one <see cref="Object1DSpan"/>
        /// per run of identical, positive instance ids to `boundingBoxes`.
        /// </summary>
        [BurstCompile]
        struct ComputeHistogramPerRowJob : IJob
        {
            // The rows scanned by this job, row-major, `width` values per row.
            [ReadOnly]
            public NativeSlice<uint> segmentationImageData;
            public int width;
            public int rows;
            // Index in the full image of the first row in segmentationImageData.
            public int rowStart;

            // Compute gives every scheduled job its own list.
            [NativeDisableContainerSafetyRestriction]
            public NativeList<Object1DSpan> boundingBoxes;

            public void Execute()
            {
                for (var row = 0; row < rows; row++)
                {
                    var rowSlice = new NativeSlice<uint>(segmentationImageData, width * row, width);
                    // -1 never equals a (uint) pixel value, so the first pixel always opens a new span.
                    var currentBB = new Object1DSpan
                    {
                        instanceId = -1,
                        row = row + rowStart
                    };
                    for (var i = 0; i < rowSlice.Length; i++)
                    {
                        var value = rowSlice[i];
                        if (value != currentBB.instanceId)
                        {
                            // Close the previous span; spans with non-positive ids are not recorded.
                            if (currentBB.instanceId > 0)
                            {
                                currentBB.right = i - 1;
                                boundingBoxes.Add(currentBB);
                            }
                            currentBB = new Object1DSpan
                            {
                                instanceId = (int)value,
                                left = i,
                                row = row + rowStart
                            };
                        }
                    }
                    // Close the span that reaches the end of the row.
                    if (currentBB.instanceId > 0)
                    {
                        currentBB.right = width - 1;
                        boundingBoxes.Add(currentBB);
                    }
                }
            }
        }

        // Maps instance id -> label index; k_UnlabeledClassId marks ids that have no label.
        NativeList<int> m_InstanceIdToClassIdLookup;
        readonly LabelingConfiguration m_LabelingConfiguration;

        // ReSharper disable once InvalidXmlDocComment
        /// <summary>
        /// Create a new RenderedObjectInfoGenerator with the given LabelingConfiguration.
        /// </summary>
        /// <param name="labelingConfiguration">The LabelingConfiguration to use to determine labelId. Should match the
        /// one used when generating the input image.</param>
        public RenderedObjectInfoGenerator(LabelingConfiguration labelingConfiguration)
        {
            m_LabelingConfiguration = labelingConfiguration;
            m_InstanceIdToClassIdLookup = new NativeList<int>(k_StartingObjectCount, Allocator.Persistent);
        }

        /// <inheritdoc/>
        public void SetupMaterialProperties(MaterialPropertyBlock mpb, MeshRenderer meshRenderer, Labeling labeling, uint instanceId)
        {
            if (!m_LabelingConfiguration.TryGetMatchingConfigurationIndex(labeling, out var index))
                return;

            if (m_InstanceIdToClassIdLookup.Length <= instanceId)
            {
                var oldLength = m_InstanceIdToClassIdLookup.Length;
                m_InstanceIdToClassIdLookup.Resize((int)instanceId + 1, NativeArrayOptions.ClearMemory);
                // Ids skipped over by the resize have no label yet. Leaving them at 0 would alias label index 0.
                for (var i = oldLength; i < (int)instanceId; i++)
                    m_InstanceIdToClassIdLookup[i] = k_UnlabeledClassId;
            }

            m_InstanceIdToClassIdLookup[(int)instanceId] = index;
        }

        // ReSharper disable once InvalidXmlDocComment
        /// <summary>
        /// Compute RenderedObjectInfo for each visible object in the given instance segmentation image.
        /// InstanceSegmentationRawData should be the raw data from a texture filled by an instance segmentation pass
        /// using the same LabelingConfiguration that was passed into this object.
        /// </summary>
        /// <param name="instanceSegmentationRawData">The instance segmentation image, row-major with
        /// <paramref name="stride"/> values per row. Pixels with value 0 belong to no object.</param>
        /// <param name="stride">The number of values per image row. Must be positive.</param>
        /// <param name="boundingBoxOrigin">When TopLeft, the y coordinate is flipped so that the last image row
        /// maps to y = 0; otherwise y is the row index.</param>
        /// <param name="boundingBoxes">Receives one entry for each visible object whose instance id has a label.
        /// Allocated with <paramref name="allocator"/>; the caller is responsible for disposing it.</param>
        /// <param name="classCounts">Receives, per label index, the number of visible labeled objects.
        /// Allocated with <paramref name="allocator"/>; the caller is responsible for disposing it.</param>
        /// <param name="allocator">Allocator used for <paramref name="boundingBoxes"/> and
        /// <paramref name="classCounts"/>.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="stride"/> is zero or negative.</exception>
        public void Compute(NativeArray<uint> instanceSegmentationRawData, int stride, BoundingBoxOrigin boundingBoxOrigin, out NativeArray<RenderedObjectInfo> boundingBoxes, out NativeArray<int> classCounts, Allocator allocator)
        {
            if (stride <= 0)
                throw new ArgumentOutOfRangeException(nameof(stride), stride, "stride must be positive.");

            const int jobCount = 24;
            var height = instanceSegmentationRawData.Length / stride;

            var handles = new NativeArray<JobHandle>(jobCount, Allocator.Temp);
            var jobBoundingBoxLists = new NativeList<Object1DSpan>[jobCount];
            var boundingBoxMap = new NativeHashMap<int, RenderedObjectInfo>(100, Allocator.Temp);
            try
            {
                using (s_LabelJobs.Auto())
                {
                    ScheduleRowJobs(instanceSegmentationRawData, stride, height, handles, jobBoundingBoxLists);
                    JobHandle.CompleteAll(handles);
                }

                classCounts = new NativeArray<int>(m_LabelingConfiguration.LabelingConfigurations.Count, allocator);
                using (s_LabelMerge.Auto())
                {
                    MergeRowSpans(jobBoundingBoxLists, boundingBoxMap);
                    boundingBoxes = BuildRenderedObjectInfos(boundingBoxMap, height, boundingBoxOrigin, classCounts, allocator);
                }
            }
            finally
            {
                // Normally a no-op. If scheduling was interrupted by an exception, this ensures no job is still
                // writing into a list that is about to be disposed.
                JobHandle.CompleteAll(handles);
                boundingBoxMap.Dispose();
                foreach (var rowBoundingBox in jobBoundingBoxLists)
                {
                    if (rowBoundingBox.IsCreated)
                        rowBoundingBox.Dispose();
                }
                handles.Dispose();
            }
        }

        // Splits the image rows as evenly as possible over handles.Length jobs and schedules one
        // ComputeHistogramPerRowJob per share. Each job gets height / jobCount rows, and the first
        // height % jobCount jobs get one extra row. When height < jobCount, only `height` jobs of one row are used.
        static void ScheduleRowJobs(NativeArray<uint> imageData, int stride, int height, NativeArray<JobHandle> handles, NativeList<Object1DSpan>[] jobBoundingBoxLists)
        {
            var jobCount = jobBoundingBoxLists.Length;
            var rowsPerJob = height / jobCount;
            var rowRemainder = height % jobCount;
            for (int row = 0, jobIndex = 0; row < height; jobIndex++)
            {
                var rowsThisJob = jobIndex < rowRemainder ? rowsPerJob + 1 : rowsPerJob;
                jobBoundingBoxLists[jobIndex] = new NativeList<Object1DSpan>(10, Allocator.TempJob);
                handles[jobIndex] = new ComputeHistogramPerRowJob
                {
                    segmentationImageData = new NativeSlice<uint>(imageData, row * stride, stride * rowsThisJob),
                    width = stride,
                    rowStart = row,
                    rows = rowsThisJob,
                    boundingBoxes = jobBoundingBoxLists[jobIndex]
                }.Schedule();
                row += rowsThisJob;
            }
        }

        // Folds the per-row spans from all jobs into one entry per instance id. Each entry holds the union of the
        // spans' rectangles (x = column, y = row, in pixels) and the total pixel count.
        static void MergeRowSpans(NativeList<Object1DSpan>[] jobBoundingBoxLists, NativeHashMap<int, RenderedObjectInfo> boundingBoxMap)
        {
            foreach (var boundingBoxList in jobBoundingBoxLists)
            {
                if (!boundingBoxList.IsCreated)
                    continue;

                foreach (var info1D in boundingBoxList)
                {
                    var spanWidth = info1D.right - info1D.left + 1;
                    var objectInfo = new RenderedObjectInfo
                    {
                        boundingBox = new Rect(info1D.left, info1D.row, spanWidth, 1),
                        instanceId = info1D.instanceId,
                        pixelCount = spanWidth
                    };
                    if (boundingBoxMap.TryGetValue(info1D.instanceId, out var info))
                    {
                        objectInfo.boundingBox = Rect.MinMaxRect(
                            math.min(info.boundingBox.xMin, objectInfo.boundingBox.xMin),
                            math.min(info.boundingBox.yMin, objectInfo.boundingBox.yMin),
                            math.max(info.boundingBox.xMax, objectInfo.boundingBox.xMax),
                            math.max(info.boundingBox.yMax, objectInfo.boundingBox.yMax));
                        objectInfo.pixelCount += info.pixelCount;
                    }
                    boundingBoxMap[info1D.instanceId] = objectInfo;
                }
            }
        }

        // Converts the merged per-instance results into the output array. Instance ids without a label are
        // dropped, and each remaining object is tallied in classCounts under its label index.
        NativeArray<RenderedObjectInfo> BuildRenderedObjectInfos(NativeHashMap<int, RenderedObjectInfo> boundingBoxMap, int height, BoundingBoxOrigin boundingBoxOrigin, NativeArray<int> classCounts, Allocator allocator)
        {
            var keyValueArrays = boundingBoxMap.GetKeyValueArrays(Allocator.Temp);
            try
            {
                var keys = keyValueArrays.Keys;

                // Size the output to the labeled objects only, so it contains no default-initialized entries.
                var labeledCount = 0;
                for (var i = 0; i < keys.Length; i++)
                {
                    if (TryGetLabelIdFromInstanceId(keys[i], out _))
                        labeledCount++;
                }

                var result = new NativeArray<RenderedObjectInfo>(labeledCount, allocator);
                var outputIndex = 0;
                for (var i = 0; i < keys.Length; i++)
                {
                    var instanceId = keys[i];
                    if (!TryGetLabelIdFromInstanceId(instanceId, out var classId))
                        continue;

                    classCounts[classId]++;
                    var renderedObjectInfo = keyValueArrays.Values[i];
                    var boundingBox = renderedObjectInfo.boundingBox;
                    if (boundingBoxOrigin == BoundingBoxOrigin.TopLeft)
                    {
                        var y = height - boundingBox.yMax;
                        boundingBox = new Rect(boundingBox.x, y, boundingBox.width, boundingBox.height);
                    }

                    result[outputIndex++] = new RenderedObjectInfo
                    {
                        instanceId = instanceId,
                        labelId = classId,
                        boundingBox = boundingBox,
                        pixelCount = renderedObjectInfo.pixelCount
                    };
                }
                return result;
            }
            finally
            {
                keyValueArrays.Dispose();
            }
        }

        /// <summary>
        /// Attempts to find the label id for the given instance id using the LabelingConfiguration passed into the constructor.
        /// </summary>
        /// <param name="instanceId">The instanceId of the object for which the labelId should be found</param>
        /// <param name="labelId">The labelId of the object. -1 if not found</param>
        /// <returns>True if a labelId is found for the given instanceId.</returns>
        public bool TryGetLabelIdFromInstanceId(int instanceId, out int labelId)
        {
            labelId = -1;
            if (instanceId < 0 || instanceId >= m_InstanceIdToClassIdLookup.Length)
                return false;

            labelId = m_InstanceIdToClassIdLookup[instanceId];
            // Ids that were only filled in as gaps by SetupMaterialProperties carry the sentinel, which equals -1.
            return labelId != k_UnlabeledClassId;
        }

        /// <inheritdoc/>
        public void Dispose()
        {
            // Guarded so that calling Dispose more than once does not throw.
            if (m_InstanceIdToClassIdLookup.IsCreated)
                m_InstanceIdToClassIdLookup.Dispose();
        }
    }
}