using System.Collections.Generic;
using UnityEngine;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Abstract class for managing the transforms of a hierarchy of objects.
    /// This could be GameObjects or Monobehaviours in the scene graph, but this is
    /// not a requirement; for example, the objects could be rigid bodies whose hierarchy
    /// is defined by Joint configurations.
    ///
    /// Poses are either considered in model space, which is relative to a root body,
    /// or in local space, which is relative to their parent.
    /// </summary>
    public abstract class PoseExtractor
    {
        // Parent index of each body; -1 marks a body with no parent (the root, index 0).
        int[] m_ParentIndices;

        // Per-body caches, all of length NumPoses once Setup() has run. They are filled by
        // UpdateModelSpacePoses() / UpdateLocalSpacePoses() and read by the iterators below.
        Pose[] m_ModelSpacePoses;
        Pose[] m_LocalSpacePoses;
        Vector3[] m_ModelSpaceLinearVelocities;
        Vector3[] m_LocalSpaceLinearVelocities;

        // Whether each body contributes to the GetEnabled* iterators and NumEnabledPoses.
        bool[] m_PoseEnabled;

        /// <summary>
        /// Read iterator for the enabled model space transforms.
        /// </summary>
        /// <returns>
        /// The cached model space pose of each enabled body, in index order, as of the last call to
        /// <see cref="UpdateModelSpacePoses"/>. Empty if Setup() has not been called.
        /// </returns>
        public IEnumerable<Pose> GetEnabledModelSpacePoses()
        {
            if (m_ModelSpacePoses == null)
            {
                yield break;
            }

            for (var i = 0; i < m_ModelSpacePoses.Length; i++)
            {
                if (m_PoseEnabled[i])
                {
                    yield return m_ModelSpacePoses[i];
                }
            }
        }

        /// <summary>
        /// Read iterator for the enabled local space transforms.
        /// </summary>
        /// <returns>
        /// The cached local (parent-relative) pose of each enabled body, in index order, as of the
        /// last call to <see cref="UpdateLocalSpacePoses"/>. Empty if Setup() has not been called.
        /// </returns>
        public IEnumerable<Pose> GetEnabledLocalSpacePoses()
        {
            if (m_LocalSpacePoses == null)
            {
                yield break;
            }

            for (var i = 0; i < m_LocalSpacePoses.Length; i++)
            {
                if (m_PoseEnabled[i])
                {
                    yield return m_LocalSpacePoses[i];
                }
            }
        }

        /// <summary>
        /// Read iterator for the enabled model space linear velocities.
        /// </summary>
        /// <returns>
        /// The cached model space linear velocity of each enabled body, in index order, as of the
        /// last call to <see cref="UpdateModelSpacePoses"/>. Empty if Setup() has not been called.
        /// </returns>
        public IEnumerable<Vector3> GetEnabledModelSpaceVelocities()
        {
            if (m_ModelSpaceLinearVelocities == null)
            {
                yield break;
            }

            for (var i = 0; i < m_ModelSpaceLinearVelocities.Length; i++)
            {
                if (m_PoseEnabled[i])
                {
                    yield return m_ModelSpaceLinearVelocities[i];
                }
            }
        }

        /// <summary>
        /// Read iterator for the enabled local space linear velocities.
        /// </summary>
        /// <returns>
        /// The cached local space linear velocity of each enabled body, in index order, as of the
        /// last call to <see cref="UpdateLocalSpacePoses"/>. Empty if Setup() has not been called.
        /// </returns>
        public IEnumerable<Vector3> GetEnabledLocalSpaceVelocities()
        {
            if (m_LocalSpaceLinearVelocities == null)
            {
                yield break;
            }

            for (var i = 0; i < m_LocalSpaceLinearVelocities.Length; i++)
            {
                if (m_PoseEnabled[i])
                {
                    yield return m_LocalSpaceLinearVelocities[i];
                }
            }
        }

        /// <summary>
        /// Number of enabled poses in the hierarchy (read-only).
        /// Returns 0 if Setup() has not been called.
        /// </summary>
        public int NumEnabledPoses
        {
            get
            {
                if (m_PoseEnabled == null)
                {
                    return 0;
                }

                var numEnabled = 0;
                for (var i = 0; i < m_PoseEnabled.Length; i++)
                {
                    numEnabled += m_PoseEnabled[i] ? 1 : 0;
                }

                return numEnabled;
            }
        }

        /// <summary>
        /// Number of total poses in the hierarchy (read-only).
        /// Returns 0 if Setup() has not been called.
        /// </summary>
        public int NumPoses
        {
            get { return m_ModelSpacePoses?.Length ?? 0; }
        }

        /// <summary>
        /// Get the parent index of the body at the specified index.
        /// </summary>
        /// <param name="index">Index of the body in the hierarchy.</param>
        /// <returns>
        /// The index of the body's parent, or -1 if the body is the root or Setup() has not been called.
        /// </returns>
        public int GetParentIndex(int index)
        {
            if (m_ParentIndices == null)
            {
                return -1;
            }

            return m_ParentIndices[index];
        }

        /// <summary>
        /// Set whether the pose at the given index is enabled or disabled for observations.
        /// </summary>
        /// <param name="index">Index of the body in the hierarchy.</param>
        /// <param name="val">True to include the body in the enabled iterators, false to exclude it.</param>
        public void SetPoseEnabled(int index, bool val)
        {
            m_PoseEnabled[index] = val;
        }

        /// <summary>
        /// Initialize with the mapping of parent indices.
        /// The 0th element is assumed to be -1, indicating that it's the root.
        /// All poses start out enabled.
        /// </summary>
        /// <param name="parentIndices">Parent index of each body; -1 for the root.</param>
        protected void Setup(int[] parentIndices)
        {
#if DEBUG
            // Guard the length so an empty hierarchy doesn't fail with an unrelated IndexOutOfRangeException.
            if (parentIndices.Length > 0 && parentIndices[0] != -1)
            {
                throw new UnityAgentsException($"Expected parentIndices[0] to be -1, got {parentIndices[0]}");
            }
#endif
            m_ParentIndices = parentIndices;
            var numPoses = parentIndices.Length;
            m_ModelSpacePoses = new Pose[numPoses];
            m_LocalSpacePoses = new Pose[numPoses];

            m_ModelSpaceLinearVelocities = new Vector3[numPoses];
            m_LocalSpaceLinearVelocities = new Vector3[numPoses];

            m_PoseEnabled = new bool[numPoses];
            // All poses are enabled by default. Generally we'll want to disable the root though.
            for (var i = 0; i < numPoses; i++)
            {
                m_PoseEnabled[i] = true;
            }
        }

        /// <summary>
        /// Return the world space Pose of the i'th object.
        /// </summary>
        /// <param name="index">Index of the body in the hierarchy.</param>
        /// <returns>The world space pose of the body.</returns>
        protected internal abstract Pose GetPoseAt(int index);

        /// <summary>
        /// Return the world space linear velocity of the i'th object.
        /// </summary>
        /// <param name="index">Index of the body in the hierarchy.</param>
        /// <returns>The world space linear velocity of the body.</returns>
        protected internal abstract Vector3 GetLinearVelocityAt(int index);

        /// <summary>
        /// Update the internal model space transform storage based on the underlying system.
        /// Each pose is expressed relative to the root body (index 0). Each linear velocity is the
        /// difference from the root's velocity, rotated into the root's frame.
        /// </summary>
        public void UpdateModelSpacePoses()
        {
            using (TimerStack.Instance.Scoped("UpdateModelSpacePoses"))
            {
                // Nothing to do before Setup(), or for an empty hierarchy (there is no root to query).
                if (m_ModelSpacePoses == null || m_ModelSpacePoses.Length == 0)
                {
                    return;
                }

                var rootWorldTransform = GetPoseAt(0);
                var worldToModel = rootWorldTransform.Inverse();
                var rootLinearVel = GetLinearVelocityAt(0);

                for (var i = 0; i < m_ModelSpacePoses.Length; i++)
                {
                    var currentWorldSpacePose = GetPoseAt(i);
                    var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
                    m_ModelSpacePoses[i] = currentModelSpacePose;

                    var currentBodyLinearVel = GetLinearVelocityAt(i);
                    var relativeVelocity = currentBodyLinearVel - rootLinearVel;
                    // Velocities are directions, so only the rotation part of the transform applies.
                    m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
                }
            }
        }

        /// <summary>
        /// Update the internal local space transform storage based on the underlying system.
        /// Each pose and linear velocity is expressed relative to the body's parent; bodies without
        /// a parent get the identity pose and zero velocity.
        /// </summary>
        public void UpdateLocalSpacePoses()
        {
            using (TimerStack.Instance.Scoped("UpdateLocalSpacePoses"))
            {
                if (m_LocalSpacePoses == null)
                {
                    return;
                }

                for (var i = 0; i < m_LocalSpacePoses.Length; i++)
                {
                    if (m_ParentIndices[i] != -1)
                    {
                        var parentTransform = GetPoseAt(m_ParentIndices[i]);
                        // This is slightly inefficient, since for a body with multiple children, we'll end up inverting
                        // the transform multiple times. Might be able to trade space for perf here.
                        var invParent = parentTransform.Inverse();
                        var currentTransform = GetPoseAt(i);
                        m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);

                        var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
                        var currentLinearVel = GetLinearVelocityAt(i);
                        m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
                    }
                    else
                    {
                        m_LocalSpacePoses[i] = Pose.identity;
                        m_LocalSpaceLinearVelocities[i] = Vector3.zero;
                    }
                }
            }
        }

        /// <summary>
        /// Compute the number of floats needed to represent the poses for the given PhysicsSensorSettings.
        /// Translations and velocities take 3 floats each, rotations 4, and the total is multiplied by
        /// the number of enabled poses.
        /// </summary>
        /// <param name="settings">Which pose quantities are included in the observation.</param>
        /// <returns>The number of floats in the observation.</returns>
        public int GetNumPoseObservations(PhysicsSensorSettings settings)
        {
            int obsPerPose = 0;
            obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0;
            obsPerPose += settings.UseModelSpaceRotations ? 4 : 0;
            obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0;
            obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0;

            obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
            obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;

            return NumEnabledPoses * obsPerPose;
        }

        /// <summary>
        /// Debug-draw the hierarchy in model space, shifted by <paramref name="offset"/>.
        /// For each non-root body, a cyan line is drawn to its parent. Three short axis lines
        /// (red = up, green = forward, blue = right) are drawn, oriented by the body's local space
        /// (parent-relative) rotation. Does nothing if Setup() has not been called.
        /// </summary>
        /// <param name="offset">Translation added to every drawn point.</param>
        internal void DrawModelSpace(Vector3 offset)
        {
            if (m_ModelSpacePoses == null)
            {
                return;
            }

            UpdateLocalSpacePoses();
            UpdateModelSpacePoses();

            var pose = m_ModelSpacePoses;
            var localPose = m_LocalSpacePoses;
            for (var i = 0; i < pose.Length; i++)
            {
                var current = pose[i];
                if (m_ParentIndices[i] == -1)
                {
                    continue;
                }

                var parent = pose[m_ParentIndices[i]];
                Debug.DrawLine(current.position + offset, parent.position + offset, Color.cyan);
                var localUp = localPose[i].rotation * Vector3.up;
                var localFwd = localPose[i].rotation * Vector3.forward;
                var localRight = localPose[i].rotation * Vector3.right;
                Debug.DrawLine(current.position + offset, current.position + offset + .1f * localUp, Color.red);
                Debug.DrawLine(current.position + offset, current.position + offset + .1f * localFwd, Color.green);
                Debug.DrawLine(current.position + offset, current.position + offset + .1f * localRight, Color.blue);
            }
        }
    }

    /// <summary>
    /// Extension methods for the Pose struct, in order to improve the readability of some math.
    /// </summary>
    public static class PoseExtensions
    {
        /// <summary>
        /// Compute the inverse of a Pose. For any Pose P,
        ///   P.Inverse().Multiply(P)
        /// will equal the identity pose (within tolerance).
        /// </summary>
        /// <param name="pose">The pose to invert.</param>
        /// <returns>The pose that undoes <paramref name="pose"/>.</returns>
        public static Pose Inverse(this Pose pose)
        {
            var rotationInverse = Quaternion.Inverse(pose.rotation);
            var translationInverse = -(rotationInverse * pose.position);
            return new Pose { rotation = rotationInverse, position = translationInverse };
        }

        /// <summary>
        /// This is equivalent to Pose.GetTransformedBy(), but keeps the order more intuitive.
        /// The result transforms a point by <paramref name="rhs"/> first, then by <paramref name="pose"/>.
        /// </summary>
        /// <param name="pose">The outer (left-hand) pose.</param>
        /// <param name="rhs">The inner (right-hand) pose.</param>
        /// <returns>The composition pose * rhs.</returns>
        public static Pose Multiply(this Pose pose, Pose rhs)
        {
            return rhs.GetTransformedBy(pose);
        }

        /// <summary>
        /// Transform the vector by the pose. Conceptually this is equivalent to treating the Pose
        /// as a 4x4 matrix and multiplying the augmented vector.
        /// See https://en.wikipedia.org/wiki/Affine_transformation#Augmented_matrix for more details.
        /// </summary>
        /// <param name="pose">The pose to apply.</param>
        /// <param name="rhs">The point to transform.</param>
        /// <returns>The point rotated by the pose's rotation and then translated by its position.</returns>
        public static Vector3 Multiply(this Pose pose, Vector3 rhs)
        {
            return pose.rotation * rhs + pose.position;
        }

        // TODO optimize inv(A)*B?
    }
}