using System.Collections.Generic;
using UnityEngine;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Utility class to track a hierarchy of RigidBodies. These are assumed to have a root node,
    /// and child nodes are connected to their parents via Joints.
    /// </summary>
    public class RigidBodyPoseExtractor : PoseExtractor
    {
        // Bodies in the order returned by GetComponentsInChildren; the root is at index 0.
        // Remains null if the constructor bails out early.
        Rigidbody[] m_Bodies;

        /// <summary>
        /// Initialize given a root RigidBody.
        /// Collects every Rigidbody under the root, then uses the Joints in the same
        /// hierarchy to record, for each jointed body, the index of the body it is connected to.
        /// If the root is null, no bodies are found, or the root is not the first body found,
        /// the extractor is left uninitialized.
        /// </summary>
        /// <param name="rootBody">The Rigidbody at the root of the hierarchy.</param>
        public RigidBodyPoseExtractor(Rigidbody rootBody)
        {
            if (rootBody == null)
            {
                return;
            }

            var rbs = rootBody.GetComponentsInChildren<Rigidbody>();

            // GetComponentsInChildren skips inactive objects by default, so the result can be
            // empty (e.g. when the root itself is inactive). Guard before touching rbs[0].
            if (rbs == null || rbs.Length == 0)
            {
                Debug.Log("No rigid bodies found!");
                return;
            }

            if (rbs[0] != rootBody)
            {
                Debug.Log("Expected root body at index 0");
                return;
            }

            // Map each body to its position in rbs so joints can be resolved to indices.
            var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
            for (var i = 0; i < rbs.Length; i++)
            {
                bodyToIndex[rbs[i]] = i;
            }

            // Entries default to 0, i.e. bodies without a resolvable joint point at the root.
            var parentIndices = new int[rbs.Length];

            var joints = rootBody.GetComponentsInChildren<Joint>();
            foreach (var joint in joints)
            {
                var parent = joint.connectedBody;
                var child = joint.GetComponent<Rigidbody>();

                // A joint with no connected body is anchored to the world rather than to
                // another body; there is no parent to record, and a null dictionary key
                // would throw.
                if (parent == null || child == null)
                {
                    continue;
                }

                // Skip joints whose endpoints are not part of the tracked hierarchy
                // instead of throwing KeyNotFoundException.
                int parentIndex;
                if (!bodyToIndex.TryGetValue(parent, out parentIndex))
                {
                    continue;
                }

                int childIndex;
                if (!bodyToIndex.TryGetValue(child, out childIndex))
                {
                    continue;
                }

                parentIndices[childIndex] = parentIndex;
            }

            m_Bodies = rbs;
            SetParentIndices(parentIndices);
        }

        /// <summary>
        /// Get the pose of the i'th RigidBody.
        /// </summary>
        /// <param name="index">Index of the body in the tracked hierarchy.</param>
        /// <returns>A Pose built from the body's rotation and position.</returns>
        protected override Pose GetPoseAt(int index)
        {
            var body = m_Bodies[index];
            return new Pose { rotation = body.rotation, position = body.position };
        }
    }
}