using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Utility class to track a hierarchy of Rigidbodies. These are assumed to have a root node,
    /// and child nodes are connected to their parents via Joints.
    /// </summary>
    public class RigidBodyPoseExtractor : PoseExtractor
    {
        /// <summary>
        /// Bodies in pose-index order. When a virtual root is used, index 0 is null and the real
        /// root Rigidbody is at index 1. Remains null if the constructor returned early.
        /// </summary>
        readonly Rigidbody[] m_Bodies;

        /// <summary>
        /// Optional game object used to determine the root of the poses, separate from the actual Rigidbodies
        /// in the hierarchy. For locomotion tasks with ragdolls, this provides a stabilized reference frame.
        /// </summary>
        readonly GameObject m_VirtualRoot;

        /// <summary>
        /// Initialize given a root Rigidbody.
        /// </summary>
        /// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
        /// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.
        /// If null, the hierarchy under <paramref name="rootBody"/> is searched instead.</param>
        /// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
        /// separate from the actual Rigidbodies in the hierarchy. For locomotion tasks, with ragdolls, this provides
        /// a stabilized reference frame, which can improve learning.</param>
        /// <param name="enableBodyPoses">Optional mapping of whether a body's pose should be enabled or not.
        /// Bodies that were not found in the searched hierarchy are ignored.</param>
        /// <remarks>
        /// If <paramref name="rootBody"/> is null, no Rigidbodies are found, or the first Rigidbody found is not
        /// <paramref name="rootBody"/>, the constructor returns early without setting up any poses.
        /// Joints whose connected body is null (anchored to the world) or lies outside the searched
        /// hierarchy do not contribute a parent link and are skipped.
        /// </remarks>
        public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null,
            GameObject virtualRoot = null, Dictionary<Rigidbody, bool> enableBodyPoses = null)
        {
            if (rootBody == null)
            {
                return;
            }

            Rigidbody[] rbs;
            Joint[] joints;
            if (rootGameObject == null)
            {
                rbs = rootBody.GetComponentsInChildren<Rigidbody>();
                joints = rootBody.GetComponentsInChildren<Joint>();
            }
            else
            {
                rbs = rootGameObject.GetComponentsInChildren<Rigidbody>();
                joints = rootGameObject.GetComponentsInChildren<Joint>();
            }

            if (rbs == null || rbs.Length == 0)
            {
                Debug.Log("No rigid bodies found!");
                return;
            }

            if (rbs[0] != rootBody)
            {
                Debug.Log("Expected root body at index 0");
                return;
            }

            // Adjust the array if we have a virtual root.
            // This will be at index 0 (left null), and the "real" root will be parented to it.
            if (virtualRoot != null)
            {
                var extendedRbs = new Rigidbody[rbs.Length + 1];
                for (var i = 0; i < rbs.Length; i++)
                {
                    extendedRbs[i + 1] = rbs[i];
                }

                rbs = extendedRbs;
            }

            var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
            var parentIndices = new int[rbs.Length];
            parentIndices[0] = -1;

            for (var i = 0; i < rbs.Length; i++)
            {
                // The virtual-root slot is null and must not be used as a dictionary key.
                if (rbs[i] != null)
                {
                    bodyToIndex[rbs[i]] = i;
                }
            }

            foreach (var j in joints)
            {
                var parent = j.connectedBody;
                var child = j.GetComponent<Rigidbody>();

                // A Joint with no connectedBody is attached to the world, and a Joint may connect to a
                // body outside the searched hierarchy. Neither gives a usable parent link, so skip it
                // rather than throwing from the dictionary lookup.
                if (parent == null || child == null)
                {
                    continue;
                }

                if (!bodyToIndex.TryGetValue(parent, out var parentIndex))
                {
                    continue;
                }

                if (!bodyToIndex.TryGetValue(child, out var childIndex))
                {
                    continue;
                }

                parentIndices[childIndex] = parentIndex;
            }

            if (virtualRoot != null)
            {
                // Make sure the original root treats the virtual root as its parent.
                parentIndices[1] = 0;
                m_VirtualRoot = virtualRoot;
            }

            m_Bodies = rbs;
            Setup(parentIndices);

            // By default, ignore the root
            SetPoseEnabled(0, false);

            if (enableBodyPoses != null)
            {
                foreach (var pair in enableBodyPoses)
                {
                    var rb = pair.Key;
                    if (bodyToIndex.TryGetValue(rb, out var index))
                    {
                        SetPoseEnabled(index, pair.Value);
                    }
                }
            }
        }

        /// <inheritdoc/>
        protected internal override Vector3 GetLinearVelocityAt(int index)
        {
            if (index == 0 && m_VirtualRoot != null)
            {
                // No velocity on the virtual root
                return Vector3.zero;
            }

            return m_Bodies[index].velocity;
        }

        /// <inheritdoc/>
        protected internal override Pose GetPoseAt(int index)
        {
            if (index == 0 && m_VirtualRoot != null)
            {
                // Use the GameObject's world transform
                return new Pose
                {
                    rotation = m_VirtualRoot.transform.rotation,
                    position = m_VirtualRoot.transform.position
                };
            }

            var body = m_Bodies[index];
            return new Pose { rotation = body.rotation, position = body.position };
        }

        /// <inheritdoc/>
        protected internal override Object GetObjectAt(int index)
        {
            if (index == 0 && m_VirtualRoot != null)
            {
                return m_VirtualRoot;
            }

            return m_Bodies[index];
        }

        /// <summary>
        /// The tracked bodies in pose-index order. Index 0 is null when a virtual root is used;
        /// the whole array is null if the constructor returned early.
        /// </summary>
        internal Rigidbody[] Bodies => m_Bodies;

        /// <summary>
        /// Get a dictionary indicating which Rigidbodies' poses are enabled or disabled.
        /// </summary>
        /// <returns>
        /// A mapping from each tracked Rigidbody (excluding the virtual root) to whether its pose is enabled.
        /// Empty if the constructor returned early and no bodies are tracked.
        /// </returns>
        internal Dictionary<Rigidbody, bool> GetBodyPosesEnabled()
        {
            // Mirror GetEnabledRigidbodies: an extractor built with an invalid root has no bodies.
            if (m_Bodies == null)
            {
                return new Dictionary<Rigidbody, bool>();
            }

            var bodyPosesEnabled = new Dictionary<Rigidbody, bool>(m_Bodies.Length);
            for (var i = 0; i < m_Bodies.Length; i++)
            {
                var rb = m_Bodies[i];
                if (rb == null)
                {
                    continue; // skip virtual root
                }

                bodyPosesEnabled[rb] = IsPoseEnabled(i);
            }

            return bodyPosesEnabled;
        }

        /// <summary>
        /// Enumerate the tracked Rigidbodies whose poses are enabled, skipping the virtual root.
        /// Yields nothing if the constructor returned early.
        /// </summary>
        /// <returns>The enabled Rigidbodies in pose-index order.</returns>
        internal IEnumerable<Rigidbody> GetEnabledRigidbodies()
        {
            if (m_Bodies == null)
            {
                yield break;
            }

            for (var i = 0; i < m_Bodies.Length; i++)
            {
                var rb = m_Bodies[i];
                if (rb == null)
                {
                    // Ignore a virtual root.
                    continue;
                }

                if (IsPoseEnabled(i))
                {
                    yield return rb;
                }
            }
        }
    }
}