浏览代码

[MLA-240] Physic-based pose generation (#4166)

* hierarchy POC

* WIP

* abstract class

* clean up init

* separate files

* cleanup, unit test

* add Articulation util

* sensor WIP

* ArticulationBody sensor, starting docs

* docstrings, cleanup

* hierarchy tests, transform operators

* unit tests

* use Pose struct instead

* delete QTTransform

* rename

* renames and compile fixes

* remove ArticulationBodySensor* for now

* revert CrawlerAgent changes

* rename
/MLA-1734-demo-provider
GitHub 5 年前
当前提交
fc6bc4db
共有 20 个文件被更改,包括 665 次插入28 次删除
  1. 2
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs.meta
  2. 2
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs.meta
  3. 8
      com.unity.ml-agents.extensions/Runtime/Sensors.meta
  4. 8
      com.unity.ml-agents.extensions/Tests/Editor/Sensors.meta
  5. 60
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
  6. 127
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
  7. 170
      com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
  8. 3
      com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs.meta
  9. 70
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
  10. 11
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs.meta
  11. 122
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
  12. 11
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs.meta
  13. 62
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
  14. 11
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs.meta
  15. 6
      com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs
  16. 20
      com.unity.ml-agents.extensions/Tests/Editor/EditorExampleTest.cs
  17. 0
      /com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs.meta
  18. 0
      /com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs.meta

2
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs.meta


fileFormatVersion: 2
guid: 2763a568e3b8541e08b90cb15e442281
guid: 11fe037a02b4a483cb9342c3454232cd
MonoImporter:
externalObjects: {}
serializedVersion: 2

2
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs.meta


fileFormatVersion: 2
guid: 8dadec7ee45de484bac4cab4efcca17d
guid: fcb7a51f0d5f8404db7b85bd35ecc1fb
MonoImporter:
externalObjects: {}
serializedVersion: 2

8
com.unity.ml-agents.extensions/Runtime/Sensors.meta


fileFormatVersion: 2
guid: 2a66e31170bb04777b9ade862995a624
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

8
com.unity.ml-agents.extensions/Tests/Editor/Sensors.meta


fileFormatVersion: 2
guid: b7ffdca5cd8064ee6831175d7ffd3f0f
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

60
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs


#if UNITY_2020_1_OR_NEWER
using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Sensors
{
public class ArticulationBodyPoseExtractor : PoseExtractor
{
ArticulationBody[] m_Bodies;
public ArticulationBodyPoseExtractor(ArticulationBody rootBody)
{
if (!rootBody.isRoot)
{
Debug.Log("Must pass ArticulationBody.isRoot");
return;
}
var bodies = rootBody.GetComponentsInChildren <ArticulationBody>();
if (bodies[0] != rootBody)
{
Debug.Log("Expected root body at index 0");
return;
}
var numBodies = bodies.Length;
m_Bodies = bodies;
int[] parentIndices = new int[numBodies];
parentIndices[0] = -1;
var bodyToIndex = new Dictionary<ArticulationBody, int>();
for (var i = 0; i < numBodies; i++)
{
bodyToIndex[m_Bodies[i]] = i;
}
for (var i = 1; i < numBodies; i++)
{
var body = m_Bodies[i];
var parent = body.GetComponentInParent<ArticulationBody>();
parentIndices[i] = bodyToIndex[parent];
}
SetParentIndices(parentIndices);
}
protected override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];
var go = body.gameObject;
var t = go.transform;
return new Pose { rotation = t.rotation, position = t.position };
}
}
}
#endif // UNITY_2020_1_OR_NEWER

127
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs


using System;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
[Serializable]
public struct PhysicsSensorSettings
{
/// <summary>
/// Whether to use model space (relative to the root body) translations as observations.
/// </summary>
public bool UseModelSpaceTranslations;
/// <summary>
/// Whether to use model space (relative to the root body) rotatoins as observations.
/// </summary>
public bool UseModelSpaceRotations;
/// <summary>
/// Whether to use local space (relative to the parent body) translations as observations.
/// </summary>
public bool UseLocalSpaceTranslations;
/// <summary>
/// Whether to use local space (relative to the parent body) translations as observations.
/// </summary>
public bool UseLocalSpaceRotations;
/// <summary>
/// Creates a PhysicsSensorSettings with reasonable default values.
/// </summary>
/// <returns></returns>
public static PhysicsSensorSettings Default()
{
return new PhysicsSensorSettings
{
UseModelSpaceTranslations = true,
UseModelSpaceRotations = true,
};
}
/// <summary>
/// Whether any model space observations are being used.
/// </summary>
public bool UseModelSpace
{
get { return UseModelSpaceTranslations || UseModelSpaceRotations; }
}
/// <summary>
/// Whether any local space observations are being used.
/// </summary>
public bool UseLocalSpace
{
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations; }
}
/// <summary>
/// The number of floats needed to represent a given number of transforms.
/// </summary>
/// <param name="numTransforms"></param>
/// <returns></returns>
public int TransformSize(int numTransforms)
{
int obsPerTransform = 0;
obsPerTransform += UseModelSpaceTranslations ? 3 : 0;
obsPerTransform += UseModelSpaceRotations ? 4 : 0;
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;
return numTransforms * obsPerTransform;
}
}
internal static class ObservationWriterPhysicsExtensions
{
/// <summary>
/// Utility method for writing a PoseExtractor to an ObservationWriter.
/// </summary>
/// <param name="writer"></param>
/// <param name="settings"></param>
/// <param name="poseExtractor"></param>
/// <param name="baseOffset">The offset into the ObservationWriter to start writing at.</param>
/// <returns>The number of observations written.</returns>
public static int WritePoses(this ObservationWriter writer, PhysicsSensorSettings settings, PoseExtractor poseExtractor, int baseOffset = 0)
{
var offset = baseOffset;
if (settings.UseModelSpace)
{
foreach (var pose in poseExtractor.ModelSpacePoses)
{
if(settings.UseModelSpaceTranslations)
{
writer.Add(pose.position, offset);
offset += 3;
}
if (settings.UseModelSpaceRotations)
{
writer.Add(pose.rotation, offset);
offset += 4;
}
}
}
if (settings.UseLocalSpace)
{
foreach (var pose in poseExtractor.LocalSpacePoses)
{
if(settings.UseLocalSpaceTranslations)
{
writer.Add(pose.position, offset);
offset += 3;
}
if (settings.UseLocalSpaceRotations)
{
writer.Add(pose.rotation, offset);
offset += 4;
}
}
}
return offset - baseOffset;
}
}
}

170
com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs


using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Abstract class for managing the transforms of a hierarchy of objects.
/// This could be GameObjects or Monobehaviours in the scene graph, but this is
/// not a requirement; for example, the objects could be rigid bodies whose hierarchy
/// is defined by Joint configurations.
///
/// Poses are either considered in model space, which is relative to a root body,
/// or in local space, which is relative to their parent.
/// </summary>
public abstract class PoseExtractor
{
int[] m_ParentIndices;
Pose[] m_ModelSpacePoses;
Pose[] m_LocalSpacePoses;
/// <summary>
/// Read access to the model space transforms.
/// </summary>
public IList<Pose> ModelSpacePoses
{
get { return m_ModelSpacePoses; }
}
/// <summary>
/// Read access to the local space transforms.
/// </summary>
public IList<Pose> LocalSpacePoses
{
get { return m_LocalSpacePoses; }
}
/// <summary>
/// Number of transforms in the hierarchy (read-only).
/// </summary>
public int NumPoses
{
get { return m_ModelSpacePoses?.Length ?? 0; }
}
/// <summary>
/// Initialize with the mapping of parent indices.
/// The 0th element is assumed to be -1, indicating that it's the root.
/// </summary>
/// <param name="parentIndices"></param>
protected void SetParentIndices(int[] parentIndices)
{
m_ParentIndices = parentIndices;
var numTransforms = parentIndices.Length;
m_ModelSpacePoses = new Pose[numTransforms];
m_LocalSpacePoses = new Pose[numTransforms];
}
/// <summary>
/// Return the world space Pose of the i'th object.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
protected abstract Pose GetPoseAt(int index);
/// <summary>
/// Update the internal model space transform storage based on the underlying system.
/// </summary>
public void UpdateModelSpacePoses()
{
if (m_ModelSpacePoses == null)
{
return;
}
var worldTransform = GetPoseAt(0);
var worldToModel = worldTransform.Inverse();
for (var i = 0; i < m_ModelSpacePoses.Length; i++)
{
var currentTransform = GetPoseAt(i);
m_ModelSpacePoses[i] = worldToModel.Multiply(currentTransform);
}
}
/// <summary>
/// Update the internal model space transform storage based on the underlying system.
/// </summary>
public void UpdateLocalSpacePoses()
{
if (m_LocalSpacePoses == null)
{
return;
}
for (var i = 0; i < m_LocalSpacePoses.Length; i++)
{
if (m_ParentIndices[i] != -1)
{
var parentTransform = GetPoseAt(m_ParentIndices[i]);
// This is slightly inefficient, since for a body with multiple children, we'll end up inverting
// the transform multiple times. Might be able to trade space for perf here.
var invParent = parentTransform.Inverse();
var currentTransform = GetPoseAt(i);
m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
}
else
{
m_LocalSpacePoses[i] = Pose.identity;
}
}
}
public void DrawModelSpace(Vector3 offset)
{
UpdateLocalSpacePoses();
UpdateModelSpacePoses();
var pose = m_ModelSpacePoses;
var localPose = m_LocalSpacePoses;
for (var i = 0; i < pose.Length; i++)
{
var current = pose[i];
if (m_ParentIndices[i] == -1)
{
continue;
}
var parent = pose[m_ParentIndices[i]];
Debug.DrawLine(current.position + offset, parent.position + offset, Color.cyan);
var localUp = localPose[i].rotation * Vector3.up;
var localFwd = localPose[i].rotation * Vector3.forward;
var localRight = localPose[i].rotation * Vector3.right;
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localUp, Color.red);
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localFwd, Color.green);
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localRight, Color.blue);
}
}
}
public static class PoseExtensions
{
/// <summary>
/// Compute the inverse of a Pose. For any Pose P,
/// P.Inverse() * P
/// will equal the identity pose (within tolerance).
/// </summary>
/// <param name="pose"></param>
/// <returns></returns>
public static Pose Inverse(this Pose pose)
{
var rotationInverse = Quaternion.Inverse(pose.rotation);
var translationInverse = -(rotationInverse * pose.position);
return new Pose { rotation = rotationInverse, position = translationInverse };
}
/// <summary>
/// This is equivalent to Pose.GetTransformedBy(), but keeps the order more intuitive.
/// </summary>
/// <param name="pose"></param>
/// <param name="rhs"></param>
/// <returns></returns>
public static Pose Multiply(this Pose pose, Pose rhs)
{
return rhs.GetTransformedBy(pose);
}
// TODO optimize inv(A)*B?
}
}

3
com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs.meta


fileFormatVersion: 2
guid: e0f1e5147a394f428f9a6447c4a8a1f4
timeCreated: 1591919825

70
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs


using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Utility class to track a hierarchy of RigidBodies. These are assumed to have a root node,
/// and child nodes are connect to their parents via Joints.
/// </summary>
public class RigidBodyPoseExtractor : PoseExtractor
{
Rigidbody[] m_Bodies;
/// <summary>
/// Initialize given a root RigidBody.
/// </summary>
/// <param name="rootBody"></param>
public RigidBodyPoseExtractor(Rigidbody rootBody)
{
if (rootBody == null)
{
return;
}
var rbs = rootBody.GetComponentsInChildren <Rigidbody>();
var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
var parentIndices = new int[rbs.Length];
if (rbs[0] != rootBody)
{
Debug.Log("Expected root body at index 0");
return;
}
for (var i = 0; i < rbs.Length; i++)
{
bodyToIndex[rbs[i]] = i;
}
var joints = rootBody.GetComponentsInChildren <Joint>();
foreach (var j in joints)
{
var parent = j.connectedBody;
var child = j.GetComponent<Rigidbody>();
var parentIndex = bodyToIndex[parent];
var childIndex = bodyToIndex[child];
parentIndices[childIndex] = parentIndex;
}
m_Bodies = rbs;
SetParentIndices(parentIndices);
}
/// <summary>
/// Get the pose of the i'th RigidBody.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
protected override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];
return new Pose { rotation = body.rotation, position = body.position };
}
}
}

11
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs.meta


fileFormatVersion: 2
guid: 867cab4a07f244518bae3e6fdda14416
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

122
com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs


using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public class PoseExtractorTests
{
class UselessPoseExtractor : PoseExtractor
{
protected override Pose GetPoseAt(int index)
{
return Pose.identity;
}
public void Init(int[] parentIndices)
{
SetParentIndices(parentIndices);
}
}
[Test]
public void TestEmptyExtractor()
{
var poseExtractor = new UselessPoseExtractor();
// These should be no-ops
poseExtractor.UpdateLocalSpacePoses();
poseExtractor.UpdateModelSpacePoses();
Assert.AreEqual(0, poseExtractor.NumPoses);
}
[Test]
public void TestSimpleExtractor()
{
var poseExtractor = new UselessPoseExtractor();
var parentIndices = new[] { -1, 0 };
poseExtractor.Init(parentIndices);
Assert.AreEqual(2, poseExtractor.NumPoses);
}
/// <summary>
/// A simple "chain" hierarchy, where each object is parented to the one before it.
/// 0 <- 1 <- 2 <- ...
/// </summary>
class ChainPoseExtractor : PoseExtractor
{
public Vector3 offset;
public ChainPoseExtractor(int size)
{
var parents = new int[size];
for (var i = 0; i < size; i++)
{
parents[i] = i - 1;
}
SetParentIndices(parents);
}
protected override Pose GetPoseAt(int index)
{
var rotation = Quaternion.identity;
var translation = offset + new Vector3(index, index, index);
return new Pose
{
rotation = rotation,
position = translation
};
}
}
[Test]
public void TestChain()
{
var size = 4;
var chain = new ChainPoseExtractor(size);
chain.offset = new Vector3(.5f, .75f, .333f);
chain.UpdateModelSpacePoses();
chain.UpdateLocalSpacePoses();
// Root transforms are currently always the identity.
Assert.IsTrue(chain.ModelSpacePoses[0] == Pose.identity);
Assert.IsTrue(chain.LocalSpacePoses[0] == Pose.identity);
// Check the non-root transforms
for (var i = 1; i < size; i++)
{
var modelSpace = chain.ModelSpacePoses[i];
var expectedModelTranslation = new Vector3(i, i, i);
Assert.IsTrue(expectedModelTranslation == modelSpace.position);
var localSpace = chain.LocalSpacePoses[i];
var expectedLocalTranslation = new Vector3(1, 1, 1);
Assert.IsTrue(expectedLocalTranslation == localSpace.position);
}
}
}
public class PoseExtensionTests
{
[Test]
public void TestInverse()
{
Pose t = new Pose
{
rotation = Quaternion.AngleAxis(23.0f, new Vector3(1, 1, 1).normalized),
position = new Vector3(-1.0f, 2.0f, 3.0f)
};
var inverseT = t.Inverse();
var product = inverseT.Multiply(t);
Assert.IsTrue(Vector3.zero == product.position);
Assert.IsTrue(Quaternion.identity == product.rotation);
Assert.IsTrue(Pose.identity == product);
}
}
}

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs.meta


fileFormatVersion: 2
guid: 9b0deb29b2f5d4b03a7f75a516943c81
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

62
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs


using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public class RigidBodyPoseExtractorTests
{
[TearDown]
public void RemoveGameObjects()
{
var objects = GameObject.FindObjectsOfType<GameObject>();
foreach (var o in objects)
{
UnityEngine.Object.DestroyImmediate(o);
}
}
[Test]
public void TestNullRoot()
{
var poseExtractor = new RigidBodyPoseExtractor(null);
// These should be no-ops
poseExtractor.UpdateLocalSpacePoses();
poseExtractor.UpdateModelSpacePoses();
Assert.AreEqual(0, poseExtractor.NumPoses);
}
[Test]
public void TestSingleBody()
{
var go = new GameObject();
var rootRb = go.AddComponent<Rigidbody>();
var poseExtractor = new RigidBodyPoseExtractor(rootRb);
Assert.AreEqual(1, poseExtractor.NumPoses);
}
[Test]
public void TestTwoBodies()
{
// * rootObj
// - rb1
// * go2
// - rb2
// - joint
var rootObj = new GameObject();
var rb1 = rootObj.AddComponent<Rigidbody>();
var go2 = new GameObject();
var rb2 = go2.AddComponent<Rigidbody>();
go2.transform.SetParent(rootObj.transform);
var joint = go2.AddComponent<ConfigurableJoint>();
joint.connectedBody = rb1;
var poseExtractor = new RigidBodyPoseExtractor(rb1);
Assert.AreEqual(2, poseExtractor.NumPoses);
}
}
}

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs.meta


fileFormatVersion: 2
guid: 23f57b248aaf940a3962cef68c6d83f5
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

6
com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs


using Unity.MLAgents;
namespace Unity.MLAgents.Extensions
{
}

20
com.unity.ml-agents.extensions/Tests/Editor/EditorExampleTest.cs


using UnityEngine;
using UnityEditor;
using UnityEngine.TestTools;
using NUnit.Framework;
using System.Collections;
namespace Unity.MLAgents.Extensions.Tests
{
internal class EditorExampleTest {
[Test]
public void EditorTestMath()
{
Assert.AreEqual(2, 1 + 1);
}
}
}

/com.unity.ml-agents.extensions/Runtime/RuntimeExample.cs.meta → /com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs.meta

/com.unity.ml-agents.extensions/Tests/Editor/EditorExampleTest.cs.meta → /com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs.meta

正在加载...
取消
保存