using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using Unity.Collections;
using Unity.Profiling;
using UnityEngine.Perception.GroundTruth.DataModel;
using UnityEngine.Rendering;
namespace UnityEngine.Perception.GroundTruth
{
/// <summary>
/// Labeler which produces object counts for each label in the associated <see cref="IdLabelConfig"/> each frame.
/// </summary>
[Serializable]
public sealed class ObjectCountLabeler : CameraLabeler
{
    /// <summary>
    /// The object count metric records how many of a particular object are
    /// present in a capture.
    /// </summary>
    [Serializable]
    public class ObjectCountMetric : Metric
    {
        /// <summary>
        /// The number of instances counted for a single label category.
        /// </summary>
        public struct Entry
        {
            /// <summary>
            /// The label of the category.
            /// </summary>
            public string labelName;

            /// <summary>
            /// The number of instances for a particular category.
            /// </summary>
            public int count;
        }

        /// <summary>
        /// The object counts, one entry per label entry of the labeler's label config, in the same order.
        /// </summary>
        public IEnumerable objectCounts;
    }

    const string k_Id = "ObjectCount";
    const string k_Description = "Produces object counts for each label defined in this labeler's associated label configuration.";

    /// <inheritdoc/>
    public override string description
    {
        get => k_Description;
        // Assignments are ignored; the description is always k_Description.
        protected set {}
    }

    /// <summary>
    /// The <see cref="IdLabelConfig"/> which associates objects with labels.
    /// </summary>
    public IdLabelConfig labelConfig
    {
        get => m_LabelConfig;
        set => m_LabelConfig = value;
    }

    /// <summary>
    /// Fired when the object counts are computed for a frame.
    /// </summary>
    /// <remarks>
    /// Handlers receive the frame count, the per-label counts (index i corresponds to entry i of the
    /// label config's labelEntries) and those label entries. The counts are backed by an
    /// <see cref="Allocator.Temp"/> allocation, so they must not be retained beyond the current frame.
    /// </remarks>
    public event Action,IReadOnlyList> ObjectCountsComputed;

    [SerializeField]
    IdLabelConfig m_LabelConfig;

    static ProfilerMarker s_ClassCountCallback = new ProfilerMarker("OnClassLabelsReceived");

    // Async metric handles requested in OnBeginRendering, keyed by frame; each one is removed
    // when ProduceObjectCountMetric reports its values.
    Dictionary m_AsyncMetrics;

    MetricDefinition m_Definition = new MetricDefinition(k_Id, k_Description);

    /// <summary>
    /// Creates a new ObjectCountLabeler. This constructor should only be used by serialization. For creation from
    /// user code, use the <see cref="ObjectCountLabeler(IdLabelConfig)"/> constructor.
    /// </summary>
    public ObjectCountLabeler()
    {
    }

    /// <summary>
    /// Creates a new ObjectCountLabeler with the given <see cref="IdLabelConfig"/>.
    /// </summary>
    /// <param name="labelConfig">The label config for resolving the label for each object.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="labelConfig"/> is null.</exception>
    public ObjectCountLabeler(IdLabelConfig labelConfig)
    {
        if (labelConfig == null)
            throw new ArgumentNullException(nameof(labelConfig));

        m_LabelConfig = labelConfig;
    }

    /// <inheritdoc/>
    protected override bool supportsVisualization => true;

    /// <inheritdoc/>
    /// <exception cref="InvalidOperationException">Thrown when no label config has been assigned.</exception>
    protected override void Setup()
    {
        if (labelConfig == null)
            throw new InvalidOperationException("The ObjectCountLabeler idLabelConfig field must be assigned");

        m_AsyncMetrics = new Dictionary();

        perceptionCamera.RenderedObjectInfosCalculated += (frameCount, objectInfo) =>
        {
            var objectCounts = ComputeObjectCounts(objectInfo);

            // Read the entries once so the event handlers and the reported metric use the same list.
            var labelEntries = m_LabelConfig.labelEntries;

            ObjectCountsComputed?.Invoke(frameCount, objectCounts, labelEntries);
            ProduceObjectCountMetric(objectCounts, labelEntries, frameCount);
        };

        m_Definition = new MetricDefinition
        {
            id = k_Id,
            description = k_Description
        };
        DatasetCapture.Instance.RegisterMetric(m_Definition);

        visualizationEnabled = supportsVisualization;
    }

    /// <inheritdoc/>
    protected override void OnBeginRendering(ScriptableRenderContext scriptableRenderContext)
    {
        // Request the metric now; ProduceObjectCountMetric fills it in when the rendered object infos
        // for this frame arrive. NOTE(review): assumes the frameCount passed to
        // RenderedObjectInfosCalculated equals Time.frameCount here — confirm.
        m_AsyncMetrics[Time.frameCount] = perceptionCamera.SensorHandle.ReportMetricAsync(m_Definition);
    }

    // Counts the objects per label index of m_LabelConfig. Objects whose instance id has no label entry
    // are skipped. The result is Temp-allocated and sized to the label config's entry count.
    NativeArray ComputeObjectCounts(NativeArray objectInfo)
    {
        var objectCounts = new NativeArray(m_LabelConfig.labelEntries.Count, Allocator.Temp);
        foreach (var info in objectInfo)
        {
            if (!m_LabelConfig.TryGetLabelEntryFromInstanceId(info.instanceId, out _, out var labelIndex))
                continue;

            objectCounts[labelIndex]++;
        }

        return objectCounts;
    }

    // Reports the counts for frameCount through the async metric requested in OnBeginRendering, and
    // refreshes the HUD when visualization is enabled. Does nothing if no metric is pending for the frame.
    void ProduceObjectCountMetric(NativeSlice counts, IReadOnlyList entries, int frameCount)
    {
        using (s_ClassCountCallback.Auto())
        {
            if (!m_AsyncMetrics.TryGetValue(frameCount, out var classCountAsyncMetric))
                return;

            m_AsyncMetrics.Remove(frameCount);

            // The payload gets its own array instead of sharing one buffer across frames. It is handed to an
            // async metric handle and may be consumed after later frames run, so a shared buffer would let a
            // later frame overwrite counts that have not yet been written out.
            var classCountValues = new ObjectCountMetric.Entry[entries.Count];

            var visualize = visualizationEnabled;
            if (visualize)
            {
                // Clear out all of the old entries...
                hudPanel.RemoveEntries(this);
            }

            for (var i = 0; i < entries.Count; i++)
            {
                classCountValues[i] = new ObjectCountMetric.Entry
                {
                    labelName = entries[i].label,
                    count = (int)counts[i]
                };

                // Only display entries with a count greater than 0
                if (visualize && counts[i] > 0)
                {
                    var label = entries[i].label + " Counts";
                    hudPanel.UpdateEntry(this, label, counts[i].ToString());
                }
            }

            var payload = new ObjectCountMetric
            {
                sensorId = "",
                annotationId = default,
                description = m_Definition.description,
                metadata = new Dictionary(),
                objectCounts = classCountValues
            };

            classCountAsyncMetric.Report(payload);
        }
    }

    /// <inheritdoc/>
    protected override void OnVisualizerEnabledChanged(bool enabled)
    {
        if (enabled) return;

        // Visualization was switched off: remove this labeler's HUD entries.
        hudPanel.RemoveEntries(this);
    }
}
}