浏览代码

Almost working using compute shaders. Need to switch to regular rendering now.

/keypoint_self_occlusion
Jon Hogins 4 年前
当前提交
ffab50bf
共有 1 个文件被更改,包括 91 次插入51 次删除
  1. 142
      com.unity.perception/Runtime/GroundTruth/Labelers/KeypointLabeler.cs

142
com.unity.perception/Runtime/GroundTruth/Labelers/KeypointLabeler.cs


using System.Collections.Generic;
using System.Linq;
using Unity.Collections;
using Unity.Collections.LowLevel.Unsafe;
using Unity.Mathematics;
using UnityEngine.Rendering;

AnnotationDefinition m_AnnotationDefinition;
Texture2D m_MissingTexture;
Dictionary<int, (AsyncAnnotation annotation, List<KeypointEntry> keypoints, List<float> zValues)> m_AsyncAnnotations;
struct FrameKeypointData
{
public AsyncAnnotation annotation;
public int pointsPerEntry;
public List<KeypointEntry> keypoints;
public bool isDepthCheckComplete;
public bool isInstanceSegmentationCheckComplete;
public NativeArray<RenderedObjectInfo> objectInfos;
}
Dictionary<int, FrameKeypointData> m_FrameKeypointData;
List<KeypointEntry> m_KeypointEntriesToReport;
int m_CurrentFrame;

m_KnownStatus = new Dictionary<uint, CachedData>();
m_AsyncAnnotations = new Dictionary<int, (AsyncAnnotation, Dictionary<uint, KeypointEntry>)>();
m_FrameKeypointData = new Dictionary<int, FrameKeypointData>();
m_KeypointEntriesToReport = new List<KeypointEntry>();
m_CurrentFrame = 0;
m_KeypointDepthTestShader = (ComputeShader) Resources.Load("KeypointDepthTest");

void OnInstanceSegmentationImageReadback(int frameCount, NativeArray<Color32> data, RenderTexture renderTexture)
{
if (!m_AsyncAnnotations.TryGetValue(frameCount, out var asyncAnnotation))
if (!m_FrameKeypointData.TryGetValue(frameCount, out var frameKeypointData))
foreach (var keypointEntry in asyncAnnotation.keypoints)
foreach (var keypointEntry in frameKeypointData.keypoints)
{
if (InstanceIdToColorMapping.TryGetColorFromInstanceId(keypointEntry.instance_id, out var idColor))
{

keypoint.y = math.clamp(keypoint.y, 0, dimensions.height - .001f);
}
keypointSet.Value.keypoints[i] = keypoint;
keypointEntry.keypoints[i] = keypoint;
frameKeypointData.isInstanceSegmentationCheckComplete = true;
m_FrameKeypointData[frameCount] = frameKeypointData;
ReportIfComplete(frameCount, frameKeypointData);
if (!m_AsyncAnnotations.TryGetValue(frameCount, out var asyncAnnotation))
if (!m_FrameKeypointData.TryGetValue(frameCount, out var frameKeypointData))
m_AsyncAnnotations.Remove(frameCount);
frameKeypointData.objectInfos = new NativeArray<RenderedObjectInfo>(objectInfos, Allocator.Persistent);
m_FrameKeypointData[frameCount] = frameKeypointData;
ReportIfComplete(frameCount, frameKeypointData);
}
private void ReportIfComplete(int frameCount, FrameKeypointData frameKeypointData)
{
if (!frameKeypointData.isInstanceSegmentationCheckComplete || !frameKeypointData.isDepthCheckComplete || !frameKeypointData.objectInfos.IsCreated)
return;
foreach (var entry in asyncAnnotation.keypoints)
foreach (var entry in frameKeypointData.keypoints)
{
var include = false;
if (objectFilter == KeypointObjectFilter.All)

foreach (var objectInfo in objectInfos)
foreach (var objectInfo in frameKeypointData.objectInfos)
{
if (entry.instance_id == objectInfo.instanceId)
{

if (!include && objectFilter == KeypointObjectFilter.VisibleAndOccluded)
include = entry.keypoints.Any(k => k.state == 1);
}
//This code assumes that OnRenderedObjectInfoReadback will be called immediately after OnInstanceSegmentationImageReadback
m_FrameKeypointData.Remove(frameCount);
asyncAnnotation.annotation.ReportValues(m_KeypointEntriesToReport);
}
struct KeypointDepthCheckData
{
public float3 position;
frameKeypointData.annotation.ReportValues(m_KeypointEntriesToReport);
frameKeypointData.objectInfos.Dispose();
}
/// <param name="scriptableRenderContext"></param>

var annotation = perceptionCamera.SensorHandle.ReportAnnotationAsync(m_AnnotationDefinition);
var keypointEntries = new List<KeypointEntry>();
var positions = new NativeList<float3>(512, Allocator.Persistent);
m_AsyncAnnotations[m_CurrentFrame] = (annotation, keypointEntries);
foreach (var label in LabelManager.singleton.registeredLabels)
ProcessLabel(label, keypointEntries, positions);
if (keypointEntries.Count != 0)
DoDepthCheck(scriptableRenderContext, keypointEntries, positions);
foreach (var label in LabelManager.singleton.registeredLabels)
ProcessLabel(label);
m_FrameKeypointData[m_CurrentFrame] = new FrameKeypointData
{
annotation = annotation,
keypoints = keypointEntries,
pointsPerEntry = activeTemplate.keypoints.Length
};
}
private void DoDepthCheck(ScriptableRenderContext scriptableRenderContext, List<KeypointEntry> keypointEntries, NativeList<float3> positions)
{
var keypointCount = keypointEntries.Count * activeTemplate.keypoints.Length;
var commandBuffer = CommandBufferPool.Get();

var keypointDepthCheckData =
new NativeArray<KeypointDepthCheckData>(keypointCount, Allocator.Temp, NativeArrayOptions.UninitializedMemory);
int index = 0;
foreach (var keypointEntry in keypointEntries)
{
foreach (var keypoint in keypointEntry.keypoints)
{
keypointDepthCheckData[index++] = new KeypointDepthCheckData()
{
position = new float3(keypoint.x, keypoint.y)
}
}
}
var positionSize = UnsafeUtility.SizeOf<float3>();
var keypointPositionsBuffer = new ComputeBuffer(keypointCount * positionSize, positionSize,
ComputeBufferType.Default, ComputeBufferMode.Dynamic);
keypointPositionsBuffer.SetData(positions.AsArray());
var keypointDataComputeBuffer = new ComputeBuffer(keypointCount, 4, ComputeBufferType.Default, ComputeBufferMode.Dynamic);
keypointDataComputeBuffer.SetData(keypointDepthCheckData);
var resultsComputeBuffer = new ComputeBuffer(keypointCount, 4, ComputeBufferType.Default, ComputeBufferMode.Dynamic);
var resultsComputeBuffer =
new ComputeBuffer(keypointCount, 4, ComputeBufferType.Default, ComputeBufferMode.Dynamic);
commandBuffer.SetComputeBufferParam(m_KeypointDepthTestShader, 0, "CheckPositions", keypointDataComputeBuffer);
commandBuffer.SetComputeBufferParam(m_KeypointDepthTestShader, 0, "CheckPositions", keypointPositionsBuffer);
commandBuffer.RequestAsyncReadback(resultsComputeBuffer, OnDepthCheckReadback);
var currentFrame = Time.frameCount;
commandBuffer.RequestAsyncReadback(resultsComputeBuffer, request => OnDepthCheckReadback(currentFrame, request));
scriptableRenderContext.Submit();
private void OnDepthCheckReadback(AsyncGPUReadbackRequest obj)
private void OnDepthCheckReadback(int frameCount, AsyncGPUReadbackRequest obj)
var frameKeypointData = m_FrameKeypointData[frameCount];
for (var i = 0; i < data.Length; i++)
{
var value = data[i];
if (value == 0)
{
var keypoints = frameKeypointData.keypoints[i / frameKeypointData.pointsPerEntry];
var indexInObject = i % frameKeypointData.pointsPerEntry;
var keypoint = keypoints.keypoints[indexInObject];
keypoint.state = 1;
keypoints.keypoints[indexInObject] = keypoint;
}
}
frameKeypointData.isDepthCheckComplete = true;
m_FrameKeypointData[frameCount] = frameKeypointData;
ReportIfComplete(frameCount, frameKeypointData);
}
// ReSharper disable InconsistentNaming

public bool status;
public Animator animator;
public KeypointEntry keypoints;
public List<float> zValues;
public List<(JointLabel, int)> overrides;
}

return false;
}
void ProcessLabel(Labeling labeledEntity)
void ProcessLabel(Labeling labeledEntity, List<KeypointEntry> keypointEntries, NativeList<float3> positions)
{
if (!idLabelConfig.TryGetLabelEntryFromInstanceId(labeledEntity.instanceId, out var labelEntry))
return;

status = false,
animator = null,
keypoints = new KeypointEntry(),
overrides = new List<(JointLabel, int)>(),
zValues = new List<float>(activeTemplate.keypoints.Length)
overrides = new List<(JointLabel, int)>()
};
var entityGameObject = labeledEntity.gameObject;

if (cachedData.status)
{
var animator = cachedData.animator;
var keypoints = cachedData.keypoints.keypoints;
var listStart = positions.Length;
positions.Resize(positions.Length + activeTemplate.keypoints.Length, NativeArrayOptions.ClearMemory);
//grab the slice of the list for the current object to assign positions in
var positionsSlice = new NativeSlice<float3>(positions, listStart);
// Go through all of the rig keypoints and get their location
for (var i = 0; i < activeTemplate.keypoints.Length; i++)

var bone = animator.GetBoneTransform(pt.rigLabel);
if (bone != null)
{
InitKeypoint(bone.position, cachedData, i);
InitKeypoint(bone.position, cachedData, positionsSlice, i);
}
}
}

foreach (var (joint, idx) in cachedData.overrides)
{
InitKeypoint(joint.transform.position, cachedData, idx);
InitKeypoint(joint.transform.position, cachedData, positionsSlice, idx);
}
cachedData.keypoints.pose = "unset";

cachedData.keypoints.pose = GetPose(cachedData.animator);
}
var cachedKeypointEntry = cachedData.keypoints;
var keypointEntry = new KeypointEntry()

pose = cachedKeypointEntry.pose,
template_guid = cachedKeypointEntry.template_guid
};
m_AsyncAnnotations[m_CurrentFrame].keypoints.Add(keypointEntry);
keypointEntries.Add(keypointEntry);
private void InitKeypoint(Vector3 position, CachedData cachedData, int idx)
private void InitKeypoint(Vector3 position, CachedData cachedData, NativeSlice<float3> positions, int idx)
cachedData.zValues[idx] = position.z;
positions[idx] = position;
var keypoints = cachedData.keypoints.keypoints;
keypoints[idx].index = idx;
if (loc.z < 0)

正在加载...
取消
保存