浏览代码

[MLA-1141] Rigidbody and ArticulationBody sensors (#4192)

/MLA-1734-demo-provider
GitHub 5 年前
当前提交
ec006f31
共有 23 个文件被更改,包括 799 次插入48 次删除
  1. 2
      com.unity.ml-agents.extensions/LICENSE.md
  2. 28
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
  3. 44
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
  4. 85
      com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
  5. 28
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
  6. 11
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
  7. 1
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
  8. 8
      com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef
  9. 22
      com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
  10. 41
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
  11. 11
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta
  12. 93
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  13. 11
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta
  14. 52
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
  15. 11
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta
  16. 63
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs
  17. 11
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta
  18. 113
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
  19. 11
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta
  20. 113
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
  21. 11
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta
  22. 66
      com.unity.ml-agents/Runtime/SensorHelper.cs
  23. 11
      com.unity.ml-agents/Runtime/SensorHelper.cs.meta

2
com.unity.ml-agents.extensions/LICENSE.md


com.unity.ml-agents.extensions copyright © 2020 Unity Technologies
com.unity.ml-agents.extensions copyright © 2020 Unity Technologies ApS
Licensed under the Unity Companion License for Unity-dependent projects -- see
[Unity Companion License](http://www.unity3d.com/legal/licenses/Unity_Companion_License).

28
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs


namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Utility class to track a hierarchy of ArticulationBodies.
/// </summary>
public class ArticulationBodyPoseExtractor : PoseExtractor
{
ArticulationBody[] m_Bodies;

if (rootBody == null)
{
return;
}
if (!rootBody.isRoot)
{
Debug.Log("Must pass ArticulationBody.isRoot");

for (var i = 1; i < numBodies; i++)
{
var body = m_Bodies[i];
var parent = body.GetComponentInParent<ArticulationBody>();
parentIndices[i] = bodyToIndex[parent];
var currentArticBody = m_Bodies[i];
// Component.GetComponentInParent will consider the provided object as well.
// So start looking from the parent.
var currentGameObject = currentArticBody.gameObject;
var parentGameObject = currentGameObject.transform.parent;
var parentArticBody = parentGameObject.GetComponentInParent<ArticulationBody>();
parentIndices[i] = bodyToIndex[parentArticBody];
/// <inheritdoc/>
protected override Vector3 GetLinearVelocityAt(int index)
{
return m_Bodies[index].velocity;
}
/// <inheritdoc/>
protected override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];

}
}
}
#endif // UNITY_2020_1_OR_NEWER

44
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs


namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Settings that define the observations generated for physics-based sensors.
/// </summary>
[Serializable]
public struct PhysicsSensorSettings
{

public bool UseModelSpaceTranslations;
/// <summary>
/// Whether to use model space (relative to the root body) rotatoins as observations.
/// Whether to use model space (relative to the root body) rotations as observations.
/// </summary>
public bool UseModelSpaceRotations;

public bool UseLocalSpaceRotations;
/// <summary>
/// Whether to use model space (relative to the root body) linear velocities as observations.
/// </summary>
public bool UseModelSpaceLinearVelocity;
/// <summary>
/// Whether to use local space (relative to the parent body) linear velocities as observations.
/// </summary>
public bool UseLocalSpaceLinearVelocity;
/// <summary>
/// Creates a PhysicsSensorSettings with reasonable default values.
/// </summary>
/// <returns></returns>

/// </summary>
public bool UseModelSpace
{
get { return UseModelSpaceTranslations || UseModelSpaceRotations; }
get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; }
}
/// <summary>

{
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations; }
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
}

obsPerTransform += UseModelSpaceRotations ? 4 : 0;
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;
obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0;
obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0;
return numTransforms * obsPerTransform;
}

var offset = baseOffset;
if (settings.UseModelSpace)
{
foreach (var pose in poseExtractor.ModelSpacePoses)
var poses = poseExtractor.ModelSpacePoses;
var vels = poseExtractor.ModelSpaceVelocities;
for(var i=0; i<poseExtractor.NumPoses; i++)
var pose = poses[i];
if(settings.UseModelSpaceTranslations)
{
writer.Add(pose.position, offset);

{
writer.Add(pose.rotation, offset);
offset += 4;
}
if (settings.UseModelSpaceLinearVelocity)
{
writer.Add(vels[i], offset);
offset += 3;
}
}
}

foreach (var pose in poseExtractor.LocalSpacePoses)
var poses = poseExtractor.LocalSpacePoses;
var vels = poseExtractor.LocalSpaceVelocities;
for(var i=0; i<poseExtractor.NumPoses; i++)
var pose = poses[i];
if(settings.UseLocalSpaceTranslations)
{
writer.Add(pose.position, offset);

{
writer.Add(pose.rotation, offset);
offset += 4;
}
if (settings.UseLocalSpaceLinearVelocity)
{
writer.Add(vels[i], offset);
offset += 3;
}
}
}

85
com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs


Pose[] m_ModelSpacePoses;
Pose[] m_LocalSpacePoses;
Vector3[] m_ModelSpaceLinearVelocities;
Vector3[] m_LocalSpaceLinearVelocities;
/// <summary>
/// Read access to the model space transforms.
/// </summary>

}
/// <summary>
/// Number of transforms in the hierarchy (read-only).
/// Read access to the model space linear velocities.
/// </summary>
public IList<Vector3> ModelSpaceVelocities
{
get { return m_ModelSpaceLinearVelocities; }
}
/// <summary>
/// Read access to the local space linear velocities.
/// </summary>
public IList<Vector3> LocalSpaceVelocities
{
get { return m_LocalSpaceLinearVelocities; }
}
/// <summary>
/// Number of poses in the hierarchy (read-only).
/// </summary>
public int NumPoses
{

/// <summary>
/// Get the parent index of the body at the specified index.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public int GetParentIndex(int index)
{
if (m_ParentIndices == null)
{
return -1;
}
return m_ParentIndices[index];
}
/// <summary>
/// Initialize with the mapping of parent indices.
/// The 0th element is assumed to be -1, indicating that it's the root.
/// </summary>

var numTransforms = parentIndices.Length;
m_ModelSpacePoses = new Pose[numTransforms];
m_LocalSpacePoses = new Pose[numTransforms];
m_ModelSpaceLinearVelocities = new Vector3[numTransforms];
m_LocalSpaceLinearVelocities = new Vector3[numTransforms];
}
/// <summary>

protected abstract Pose GetPoseAt(int index);
/// <summary>
/// Return the world space linear velocity of the i'th object.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
protected abstract Vector3 GetLinearVelocityAt(int index);
/// <summary>
/// Update the internal model space transform storage based on the underlying system.
/// </summary>
public void UpdateModelSpacePoses()

return;
}
var worldTransform = GetPoseAt(0);
var worldToModel = worldTransform.Inverse();
var rootWorldTransform = GetPoseAt(0);
var worldToModel = rootWorldTransform.Inverse();
var rootLinearVel = GetLinearVelocityAt(0);
var currentTransform = GetPoseAt(i);
m_ModelSpacePoses[i] = worldToModel.Multiply(currentTransform);
var currentWorldSpacePose = GetPoseAt(i);
var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
m_ModelSpacePoses[i] = currentModelSpacePose;
var currentBodyLinearVel = GetLinearVelocityAt(i);
var relativeVelocity = currentBodyLinearVel - rootLinearVel;
m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
}
}

var invParent = parentTransform.Inverse();
var currentTransform = GetPoseAt(i);
m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
var currentLinearVel = GetLinearVelocityAt(i);
m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
m_LocalSpaceLinearVelocities[i] = Vector3.zero;
public void DrawModelSpace(Vector3 offset)
internal void DrawModelSpace(Vector3 offset)
{
UpdateLocalSpacePoses();
UpdateModelSpacePoses();

}
}
/// <summary>
/// Extension methods for the Pose struct, in order to improve the readability of some math.
/// </summary>
public static class PoseExtensions
{
/// <summary>

public static Pose Multiply(this Pose pose, Pose rhs)
{
return rhs.GetTransformedBy(pose);
}
/// <summary>
/// Transform the vector by the pose. Conceptually this is equivalent to treating the Pose
/// as a 4x4 matrix and multiplying the augmented vector.
/// See https://en.wikipedia.org/wiki/Affine_transformation#Augmented_matrix for more details.
/// </summary>
/// <param name="pose"></param>
/// <param name="rhs"></param>
/// <returns></returns>
public static Vector3 Multiply(this Pose pose, Vector3 rhs)
{
return pose.rotation * rhs + pose.position;
}
// TODO optimize inv(A)*B?

28
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs


/// Initialize given a root RigidBody.
/// </summary>
/// <param name="rootBody"></param>
public RigidBodyPoseExtractor(Rigidbody rootBody)
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null)
var rbs = rootBody.GetComponentsInChildren <Rigidbody>();
Rigidbody[] rbs;
if (rootGameObject == null)
{
rbs = rootBody.GetComponentsInChildren<Rigidbody>();
}
else
{
rbs = rootGameObject.GetComponentsInChildren<Rigidbody>();
}
var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
var parentIndices = new int[rbs.Length];

SetParentIndices(parentIndices);
}
/// <summary>
/// Get the pose of the i'th RigidBody.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
/// <inheritdoc/>
protected override Vector3 GetLinearVelocityAt(int index)
{
return m_Bodies[index].velocity;
}
/// <inheritdoc/>
}
}
}

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs


return Pose.identity;
}
protected override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
public void Init(int[] parentIndices)
{
SetParentIndices(parentIndices);

position = translation
};
}
protected override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
}
[Test]

1
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs


using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;

8
com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef


"name": "Unity.ML-Agents.Extensions.EditorTests",
"references": [
"Unity.ML-Agents.Extensions.Editor",
"Unity.ML-Agents.Extensions"
"Unity.ML-Agents.Extensions",
"Unity.ML-Agents"
],
"optionalUnityReferences": [
"TestAssemblies"

],
"excludePlatforms": []
"excludePlatforms": [],
"defineConstraints": [
"UNITY_INCLUDE_TESTS"
]
}

22
com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs


namespace Unity.MLAgents.Tests
{
public class SensorTestHelper
public static class SensorTestHelper
var numExpected = expected.Length;
const float fill = -1337f;
var output = new float[numExpected];
for (var i = 0; i < numExpected; i++)
{
output[i] = fill;
}
Assert.AreEqual(fill, output[0]);
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
// Make sure ObservationWriter didn't touch anything
Assert.AreEqual(fill, output[0]);
sensor.Write(writer);
Assert.AreEqual(expected, output);
string errorMessage;
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
Assert.IsTrue(isOK, errorMessage);
}
}

41
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs


#if UNITY_2020_1_OR_NEWER
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
public class ArticulationBodySensorComponent : SensorComponent
{
public ArticulationBody RootBody;
[SerializeField]
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();
public string sensorName;
/// <summary>
/// Creates a PhysicsBodySensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
{
return new PhysicsBodySensor(RootBody, Settings, sensorName);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
return new[] { numTransformObservations };
}
}
}
#endif // UNITY_2020_1_OR_NEWER

11
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta


fileFormatVersion: 2
guid: e57a788acd5e049c6aa9642b450ca318
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

93
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
/// </summary>
public class PhysicsBodySensor : ISensor
{
int[] m_Shape;
string m_SensorName;
PoseExtractor m_PoseExtractor;
PhysicsSensorSettings m_Settings;
/// <summary>
/// Construct a new PhysicsBodySensor
/// </summary>
/// <param name="rootBody"></param>
/// <param name="settings"></param>
/// <param name="sensorName"></param>
public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
{
m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
}
#if UNITY_2020_1_OR_NEWER
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
{
m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
}
#endif
/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
return numWritten;
}
/// <inheritdoc/>
public byte[] GetCompressedObservation()
{
return null;
}
/// <inheritdoc/>
public void Update()
{
if (m_Settings.UseModelSpace)
{
m_PoseExtractor.UpdateModelSpacePoses();
}
if (m_Settings.UseLocalSpace)
{
m_PoseExtractor.UpdateLocalSpacePoses();
}
}
/// <inheritdoc/>
public void Reset() {}
/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}
/// <inheritdoc/>
public string GetName()
{
return m_SensorName;
}
}
}

11
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta


fileFormatVersion: 2
guid: 254640b3578a24bd2838c1fa39f1011a
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

52
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs


using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Editor component that creates a PhysicsBodySensor for the Agent.
/// </summary>
public class RigidBodySensorComponent : SensorComponent
{
/// <summary>
/// The root Rigidbody of the system.
/// </summary>
public Rigidbody RootBody;
/// <summary>
/// Settings defining what types of observations will be generated.
/// </summary>
[SerializeField]
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();
/// <summary>
/// Optional sensor name. This must be unique for each Agent.
/// </summary>
public string sensorName;
/// <summary>
/// Creates a PhysicsBodySensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
{
return new PhysicsBodySensor(RootBody, gameObject, Settings, sensorName);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject);
var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
return new[] { numTransformObservations };
}
}
}

11
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta


fileFormatVersion: 2
guid: df0f8be9a37d6486498061e2cbc4cd94
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

63
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs


#if UNITY_2020_1_OR_NEWER
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public class ArticulationBodyPoseExtractorTests
{
[TearDown]
public void RemoveGameObjects()
{
var objects = GameObject.FindObjectsOfType<GameObject>();
foreach (var o in objects)
{
UnityEngine.Object.DestroyImmediate(o);
}
}
[Test]
public void TestNullRoot()
{
var poseExtractor = new ArticulationBodyPoseExtractor(null);
// These should be no-ops
poseExtractor.UpdateLocalSpacePoses();
poseExtractor.UpdateModelSpacePoses();
Assert.AreEqual(0, poseExtractor.NumPoses);
}
[Test]
public void TestSingleBody()
{
var go = new GameObject();
var rootArticBody = go.AddComponent<ArticulationBody>();
var poseExtractor = new ArticulationBodyPoseExtractor(rootArticBody);
Assert.AreEqual(1, poseExtractor.NumPoses);
}
[Test]
public void TestTwoBodies()
{
// * rootObj
// - rootArticBody
// * leafGameObj
// - leafArticBody
var rootObj = new GameObject();
var rootArticBody = rootObj.AddComponent<ArticulationBody>();
var leafGameObj = new GameObject();
var leafArticBody = leafGameObj.AddComponent<ArticulationBody>();
leafGameObj.transform.SetParent(rootObj.transform);
leafArticBody.jointType = ArticulationJointType.RevoluteJoint;
var poseExtractor = new ArticulationBodyPoseExtractor(rootArticBody);
Assert.AreEqual(2, poseExtractor.NumPoses);
Assert.AreEqual(-1, poseExtractor.GetParentIndex(0));
Assert.AreEqual(0, poseExtractor.GetParentIndex(1));
}
}
}
#endif

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta


fileFormatVersion: 2
guid: 934ea08cde59d4356bc41e040d333c3d
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

113
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs


#if UNITY_2020_1_OR_NEWER
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public class ArticulationBodySensorTests
{
[Test]
public void TestNullRootBody()
{
var gameObj = new GameObject();
var sensorComponent = gameObj.AddComponent<ArticulationBodySensorComponent>();
var sensor = sensorComponent.CreateSensor();
SensorTestHelper.CompareObservation(sensor, new float[0]);
}
[Test]
public void TestSingleBody()
{
var gameObj = new GameObject();
var articulationBody = gameObj.AddComponent<ArticulationBody>();
var sensorComponent = gameObj.AddComponent<ArticulationBodySensorComponent>();
sensorComponent.RootBody = articulationBody;
sensorComponent.Settings = new PhysicsSensorSettings
{
UseModelSpaceLinearVelocity = true,
UseLocalSpaceTranslations = true,
UseLocalSpaceRotations = true
};
var sensor = sensorComponent.CreateSensor();
sensor.Update();
var expected = new[]
{
0f, 0f, 0f, // ModelSpaceLinearVelocity
0f, 0f, 0f, // LocalSpaceTranslations
0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
}
[Test]
public void TestBodiesWithJoint()
{
var rootObj = new GameObject();
var rootArticBody = rootObj.AddComponent<ArticulationBody>();
var middleGamObj = new GameObject();
var middleArticBody = middleGamObj.AddComponent<ArticulationBody>();
middleArticBody.AddForce(new Vector3(0f, 1f, 0f));
middleGamObj.transform.SetParent(rootObj.transform);
middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f);
middleArticBody.jointType = ArticulationJointType.RevoluteJoint;
var leafGameObj = new GameObject();
var leafArticBody = leafGameObj.AddComponent<ArticulationBody>();
leafGameObj.transform.SetParent(middleGamObj.transform);
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
leafArticBody.jointType = ArticulationJointType.RevoluteJoint;
#if UNITY_2020_2_OR_NEWER
// ArticulationBody.velocity is read-only in 2020.1
rootArticBody.velocity = new Vector3(1f, 0f, 0f);
middleArticBody.velocity = new Vector3(0f, 1f, 0f);
leafArticBody.velocity = new Vector3(0f, 0f, 1f);
#endif
var sensorComponent = rootObj.AddComponent<ArticulationBodySensorComponent>();
sensorComponent.RootBody = rootArticBody;
sensorComponent.Settings = new PhysicsSensorSettings
{
UseModelSpaceTranslations = true,
UseLocalSpaceTranslations = true,
#if UNITY_2020_2_OR_NEWER
UseLocalSpaceLinearVelocity = true
#endif
};
var sensor = sensorComponent.CreateSensor();
sensor.Update();
var expected = new[]
{
// Model space
0f, 0f, 0f, // Root pos
13.37f, 0f, 0f, // Middle pos
leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
// Local space
0f, 0f, 0f, // Root pos
#if UNITY_2020_2_OR_NEWER
0f, 0f, 0f, // Root vel
#endif
13.37f, 0f, 0f, // Attached pos
#if UNITY_2020_2_OR_NEWER
-1f, 1f, 0f, // Attached vel
#endif
4.2f, 0f, 0f, // Leaf pos
#if UNITY_2020_2_OR_NEWER
0f, -1f, 1f // Leaf vel
#endif
};
SensorTestHelper.CompareObservation(sensor, expected);
}
}
}
#endif // #if UNITY_2020_1_OR_NEWER

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta


fileFormatVersion: 2
guid: 0ef757469348342418a68826f51d0783
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

113
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs


using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public static class SensorTestHelper
{
public static void CompareObservation(ISensor sensor, float[] expected)
{
string errorMessage;
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
Assert.IsTrue(isOK, errorMessage);
}
}
public class RigidBodySensorTests
{
[Test]
public void TestNullRootBody()
{
var gameObj = new GameObject();
var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>();
var sensor = sensorComponent.CreateSensor();
SensorTestHelper.CompareObservation(sensor, new float[0]);
}
[Test]
public void TestSingleRigidbody()
{
var gameObj = new GameObject();
var rootRb = gameObj.AddComponent<Rigidbody>();
var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>();
sensorComponent.RootBody = rootRb;
sensorComponent.Settings = new PhysicsSensorSettings
{
UseModelSpaceLinearVelocity = true,
UseLocalSpaceTranslations = true,
UseLocalSpaceRotations = true
};
var sensor = sensorComponent.CreateSensor();
sensor.Update();
var expected = new[]
{
0f, 0f, 0f, // ModelSpaceLinearVelocity
0f, 0f, 0f, // LocalSpaceTranslations
0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
}
[Test]
public void TestBodiesWithJoint()
{
var rootObj = new GameObject();
var rootRb = rootObj.AddComponent<Rigidbody>();
rootRb.velocity = new Vector3(1f, 0f, 0f);
var middleGamObj = new GameObject();
var middleRb = middleGamObj.AddComponent<Rigidbody>();
middleRb.velocity = new Vector3(0f, 1f, 0f);
middleGamObj.transform.SetParent(rootObj.transform);
middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f);
var joint = middleGamObj.AddComponent<ConfigurableJoint>();
joint.connectedBody = rootRb;
var leafGameObj = new GameObject();
var leafRb = leafGameObj.AddComponent<Rigidbody>();
leafRb.velocity = new Vector3(0f, 0f, 1f);
leafGameObj.transform.SetParent(middleGamObj.transform);
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
var joint2 = leafGameObj.AddComponent<ConfigurableJoint>();
joint2.connectedBody = middleRb;
var sensorComponent = rootObj.AddComponent<RigidBodySensorComponent>();
sensorComponent.RootBody = rootRb;
sensorComponent.Settings = new PhysicsSensorSettings
{
UseModelSpaceTranslations = true,
UseLocalSpaceTranslations = true,
UseLocalSpaceLinearVelocity = true
};
var sensor = sensorComponent.CreateSensor();
sensor.Update();
var expected = new[]
{
// Model space
0f, 0f, 0f, // Root pos
13.37f, 0f, 0f, // Middle pos
leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
// Local space
0f, 0f, 0f, // Root pos
0f, 0f, 0f, // Root vel
13.37f, 0f, 0f, // Attached pos
-1f, 1f, 0f, // Attached vel
4.2f, 0f, 0f, // Leaf pos
0f, -1f, 1f // Leaf vel
};
SensorTestHelper.CompareObservation(sensor, expected);
}
}
}

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta


fileFormatVersion: 2
guid: d8daf5517a7c94bfd9ac7f45f8d1bcd3
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

66
com.unity.ml-agents/Runtime/SensorHelper.cs


using UnityEngine;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// Utility methods related to <see cref="ISensor"/> implementations.
/// </summary>
public static class SensorHelper
{
/// <summary>
/// Generates the observations for the provided sensor, and returns true if they equal the
/// expected values. If they are unequal, errorMessage is also set.
/// This should not generally be used in production code. It is only intended for
/// simplifying unit tests.
/// </summary>
/// <param name="sensor"></param>
/// <param name="expected"></param>
/// <param name="errorMessage"></param>
/// <returns></returns>
public static bool CompareObservation(ISensor sensor, float[] expected, out string errorMessage)
{
var numExpected = expected.Length;
const float fill = -1337f;
var output = new float[numExpected];
for (var i = 0; i < numExpected; i++)
{
output[i] = fill;
}
if (numExpected > 0)
{
if (fill != output[0])
{
errorMessage = "Error setting output buffer.";
return false;
}
}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)
{
if (fill != output[0])
{
errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have.";
return false;
}
}
sensor.Write(writer);
for (var i = 0; i < output.Length; i++)
{
if (expected[i] != output[i])
{
errorMessage = $"Expected and actual differed in position {i}. Expected: {expected[i]} Actual: {output[i]} ";
return false;
}
}
errorMessage = null;
return true;
}
}
}

11
com.unity.ml-agents/Runtime/SensorHelper.cs.meta


fileFormatVersion: 2
guid: 7c1189c0af42c46f7b533350d49ad3e7
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存