using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using Unity.Collections;
using Unity.Profiling;
using UnityEngine.UI;
namespace UnityEngine.Perception.GroundTruth
{
/// <summary>
/// Labeler which produces object counts for each label in the associated <see cref="IdLabelConfig"/> each frame.
/// </summary>
[Serializable]
public sealed class ObjectCountLabeler : CameraLabeler
{
    /// <summary>
    /// The ID to use for object count annotations in the resulting dataset
    /// </summary>
    public string objectCountMetricId = "51da3c27-369d-4929-aea6-d01614635ce2";

    /// <summary>
    /// The <see cref="IdLabelConfig"/> which associates objects with labels.
    /// </summary>
    public IdLabelConfig labelConfig => m_LabelConfig;

    /// <summary>
    /// Fired when the object counts are computed for a frame. The arguments are the frame count, the per-label
    /// object counts, and the label entries the counts are indexed by.
    /// </summary>
    public event Action<int, NativeSlice<uint>, IReadOnlyList<IdLabelEntry>> ObjectCountsComputed;

    [SerializeField]
    IdLabelConfig m_LabelConfig;

    static readonly ProfilerMarker s_ClassCountCallback = new ProfilerMarker("OnClassLabelsReceived");

    // Scratch buffer reused across frames; reallocated only when the number of label entries changes.
    ClassCountValue[] m_ClassCountValues;

    // Pending async metrics keyed by the frame count on which they were requested.
    Dictionary<int, AsyncMetric> m_ObjectCountAsyncMetrics;
    MetricDefinition m_ObjectCountMetricDefinition;

    // HUD entry labels added by this labeler. Created lazily the first time a frame is visualized,
    // so it is null until then.
    List<string> vizEntries = null;

    /// <summary>
    /// Creates a new ObjectCountLabeler. This constructor should only be used by serialization. For creation from
    /// user code, use <see cref="ObjectCountLabeler(IdLabelConfig)"/>.
    /// </summary>
    public ObjectCountLabeler()
    {
    }

    /// <summary>
    /// Creates a new ObjectCountLabeler with the given <see cref="IdLabelConfig"/>.
    /// </summary>
    /// <param name="labelConfig">The label config for resolving the label for each object.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="labelConfig"/> is null.</exception>
    public ObjectCountLabeler(IdLabelConfig labelConfig)
    {
        if (labelConfig == null)
            throw new ArgumentNullException(nameof(labelConfig));

        m_LabelConfig = labelConfig;
    }

    // One reported metric value per label entry.
    // NOTE(review): the snake_case field names are presumably the names written to the dataset - confirm
    // before renaming them.
    [SuppressMessage("ReSharper", "InconsistentNaming")]
    [SuppressMessage("ReSharper", "NotAccessedField.Local")]
    struct ClassCountValue
    {
        public int label_id;
        public string label_name;
        public uint count;
    }

    /// <inheritdoc/>
    protected override bool supportsVisualization => true;

    /// <inheritdoc/>
    /// <exception cref="InvalidOperationException">Thrown when no label config has been assigned.</exception>
    protected override void Setup()
    {
        if (m_LabelConfig == null)
            throw new InvalidOperationException("The ObjectCountLabeler idLabelConfig field must be assigned");

        m_ObjectCountAsyncMetrics = new Dictionary<int, AsyncMetric>();

        // Count objects, notify listeners, then report the metric each time the camera publishes
        // its rendered-object infos.
        perceptionCamera.RenderedObjectInfosCalculated += (frameCount, objectInfo) =>
        {
            var objectCounts = ComputeObjectCounts(objectInfo);
            ObjectCountsComputed?.Invoke(frameCount, objectCounts, m_LabelConfig.labelEntries);
            ProduceObjectCountMetric(objectCounts, m_LabelConfig.labelEntries, frameCount);
        };

        visualizationEnabled = supportsVisualization;
    }

    /// <inheritdoc/>
    protected override void OnBeginRendering()
    {
        // Register the metric definition on first use; a default-valued definition means "not registered yet".
        if (m_ObjectCountMetricDefinition.Equals(default))
        {
            m_ObjectCountMetricDefinition = DatasetCapture.RegisterMetricDefinition("object count",
                m_LabelConfig.GetAnnotationSpecification(),
                "Counts of objects for each label in the sensor's view", id: new Guid(objectCountMetricId));
        }

        // Request an async metric for this frame; it is completed in ProduceObjectCountMetric.
        m_ObjectCountAsyncMetrics[Time.frameCount] = perceptionCamera.SensorHandle.ReportMetricAsync(m_ObjectCountMetricDefinition);
    }

    /// <summary>
    /// Counts the rendered objects per label entry.
    /// </summary>
    /// <param name="objectInfo">The rendered objects to count.</param>
    /// <returns>
    /// An array (allocated with <see cref="Allocator.Temp"/>) with one count per label entry, indexed by the
    /// label index. Objects whose instance id does not resolve to a label entry are not counted.
    /// </returns>
    NativeArray<uint> ComputeObjectCounts(NativeArray<RenderedObjectInfo> objectInfo)
    {
        var objectCounts = new NativeArray<uint>(m_LabelConfig.labelEntries.Count, Allocator.Temp);
        foreach (var info in objectInfo)
        {
            if (!m_LabelConfig.TryGetLabelEntryFromInstanceId(info.instanceId, out _, out var labelIndex))
                continue;

            objectCounts[labelIndex]++;
        }

        return objectCounts;
    }

    /// <summary>
    /// Reports the per-label counts to the async metric requested for <paramref name="frameCount"/> and, when
    /// visualization is enabled, mirrors them onto the HUD panel.
    /// </summary>
    /// <param name="counts">Object counts, indexed in the same order as <paramref name="entries"/>.</param>
    /// <param name="entries">The label entries the counts belong to.</param>
    /// <param name="frameCount">The frame the counts were computed for.</param>
    void ProduceObjectCountMetric(NativeSlice<uint> counts, IReadOnlyList<IdLabelEntry> entries, int frameCount)
    {
        using (s_ClassCountCallback.Auto())
        {
            // Nothing to report if no metric was requested for this frame.
            if (!m_ObjectCountAsyncMetrics.TryGetValue(frameCount, out var classCountAsyncMetric))
                return;

            m_ObjectCountAsyncMetrics.Remove(frameCount);

            if (m_ClassCountValues == null || m_ClassCountValues.Length != entries.Count)
                m_ClassCountValues = new ClassCountValue[entries.Count];

            var visualize = visualizationEnabled;
            if (visualize && vizEntries == null)
                vizEntries = new List<string>();

            for (var i = 0; i < entries.Count; i++)
            {
                var entry = entries[i];
                m_ClassCountValues[i] = new ClassCountValue
                {
                    label_id = entry.id,
                    label_name = entry.label,
                    count = counts[i]
                };

                if (!visualize)
                    continue;

                var label = entry.label + " Counts";
                hudPanel.UpdateEntry(label, counts[i].ToString());

                // Record each HUD label once. Adding it on every frame would grow the list without bound.
                if (!vizEntries.Contains(label))
                    vizEntries.Add(label);
            }

            classCountAsyncMetric.ReportValues(m_ClassCountValues);
        }
    }

    /// <inheritdoc/>
    protected override void PopulateVisualizationPanel(ControlPanel panel)
    {
        panel.AddToggleControl("Object Counts", enabled => { visualizationEnabled = enabled; });
    }

    /// <inheritdoc/>
    protected override void OnVisualizerEnabledChanged(bool enabled)
    {
        // vizEntries is only created once a frame has been visualized, so it can still be null when
        // visualization is switched off before then; there is nothing to remove in that case.
        if (enabled || vizEntries == null)
            return;

        hudPanel.RemoveEntries(vizEntries);
        vizEntries.Clear();
    }
}
}