浏览代码

[MLA-1135] Physics sensors - optional reference body (#4276)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
3a0d6e54
共有 19 个文件被更改,包括 1004 次插入1597 次删除
  1. 35
      Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab
  2. 5
      Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
  3. 12
      Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
  4. 1001
      Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn
  5. 1001
      Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn
  6. 2
      com.unity.ml-agents.extensions/README.md
  7. 6
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
  8. 14
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  9. 31
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
  10. 211
      com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
  11. 77
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
  12. 9
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
  13. 11
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
  14. 88
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
  15. 58
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
  16. 25
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
  17. 1
      com.unity.ml-agents/Runtime/AssemblyInfo.cs
  18. 3
      com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs
  19. 11
      com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta

35
Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab


- component: {fileID: 4845971001715176662}
- component: {fileID: 4845971001715176663}
- component: {fileID: 4845971001715176660}
- component: {fileID: 4622120667686875944}
m_Layer: 0
m_Name: Crawler
m_TagString: Untagged

m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
VectorObservationSize: 138
VectorObservationSize: 21
NumStackedVectorObservations: 1
VectorActionSize: 14000000
VectorActionDescriptions: []

m_Name:
m_EditorClassIdentifier:
debugCommandLineOverride:
--- !u!114 &4622120667686875944
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 4845971001715176661}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: df0f8be9a37d6486498061e2cbc4cd94, type: 3}
m_Name:
m_EditorClassIdentifier:
RootBody: {fileID: 4845971001588102145}
VirtualRoot: {fileID: 2270141184585723037}
Settings:
UseModelSpaceTranslations: 1
UseModelSpaceRotations: 1
UseLocalSpaceTranslations: 0
UseLocalSpaceRotations: 1
UseModelSpaceLinearVelocity: 1
UseLocalSpaceLinearVelocity: 0
UseJointPositionsAndAngles: 0
UseJointForces: 0
sensorName:
--- !u!1 &4845971001730692034
GameObject:
m_ObjectHideFlags: 0

objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: 72f745913c5a34df5aaadd5c1f0024cb, type: 3}
--- !u!1 &2270141184585723037 stripped
GameObject:
m_CorrespondingSourceObject: {fileID: 2591864627249999519, guid: 72f745913c5a34df5aaadd5c1f0024cb,
type: 3}
m_PrefabInstance: {fileID: 4357529801223143938}
m_PrefabAsset: {fileID: 0}
--- !u!4 &2270141184585723026 stripped
Transform:
m_CorrespondingSourceObject: {fileID: 2591864627249999504, guid: 72f745913c5a34df5aaadd5c1f0024cb,

type: 3}
m_PrefabInstance: {fileID: 4357529801223143938}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 0}
m_GameObject: {fileID: 2270141184585723037}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 771e78c5e980e440e8cd19716b55075f, type: 3}

5
Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab


propertyPath: targetToLookAt
value:
objectReference: {fileID: 2673081981996998229}
- target: {fileID: 4622120667686875944, guid: 0456c89e8c9c243d595b039fe7aa0bf9,
type: 3}
propertyPath: Settings.UseLocalSpaceLinearVelocity
value: 1
objectReference: {fileID: 0}
- target: {fileID: 4845971000000621469, guid: 0456c89e8c9c243d595b039fe7aa0bf9,
type: 3}
propertyPath: m_ConnectedAnchor.x

12
Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs


//GROUND CHECK
sensor.AddObservation(bp.groundContact.touchingGround); // Is this bp touching the ground
//Get velocities in the context of our orientation cube's space
//Note: You can get these velocities in world space as well but it may not train as well.
sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.velocity));
sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity));
//Get position relative to hips in the context of our orientation cube's space
sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.position - body.position));
sensor.AddObservation(bp.rb.transform.localRotation);
sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
}
}

/// </summary>
public override void CollectObservations(VectorSensor sensor)
{
//Add body rotation delta relative to orientation cube
sensor.AddObservation(Quaternion.FromToRotation(body.forward, orientationCube.transform.forward));
//Add pos of target relative to orientation cube
sensor.AddObservation(orientationCube.transform.InverseTransformPoint(target.transform.position));

1001
Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn
文件差异内容过多而无法显示
查看文件

1001
Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn
文件差异内容过多而无法显示
查看文件

2
com.unity.ml-agents.extensions/README.md


# ML-Agents Extensions
This is a source-only package for new features based on ML-Agents.
More details coming soon.

6
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs


parentIndices[i] = bodyToIndex[parentArticBody];
}
SetParentIndices(parentIndices);
Setup(parentIndices);
protected override Vector3 GetLinearVelocityAt(int index)
protected internal override Vector3 GetLinearVelocityAt(int index)
protected override Pose GetPoseAt(int index)
protected internal override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];
var go = body.gameObject;

14
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


/// <summary>
/// Construct a new PhysicsBodySensor
/// </summary>
/// <param name="rootBody"></param>
/// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
public PhysicsBodySensor(
Rigidbody rootBody,
GameObject rootGameObject,
GameObject virtualRoot,
PhysicsSensorSettings settings,
string sensorName=null
)
var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject, virtualRoot);
m_PoseExtractor = poseExtractor;
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;

31
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs


using System;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors

var offset = baseOffset;
if (settings.UseModelSpace)
{
var poses = poseExtractor.ModelSpacePoses;
var vels = poseExtractor.ModelSpaceVelocities;
for(var i=0; i<poseExtractor.NumPoses; i++)
foreach (var pose in poseExtractor.GetEnabledModelSpacePoses())
var pose = poses[i];
if(settings.UseModelSpaceTranslations)
if (settings.UseModelSpaceTranslations)
}
foreach(var vel in poseExtractor.GetEnabledModelSpaceVelocities())
{
writer.Add(vels[i], offset);
writer.Add(vel, offset);
offset += 3;
}
}

{
var poses = poseExtractor.LocalSpacePoses;
var vels = poseExtractor.LocalSpaceVelocities;
for(var i=0; i<poseExtractor.NumPoses; i++)
foreach (var pose in poseExtractor.GetEnabledLocalSpacePoses())
var pose = poses[i];
if(settings.UseLocalSpaceTranslations)
if (settings.UseLocalSpaceTranslations)
}
foreach(var vel in poseExtractor.GetEnabledLocalSpaceVelocities())
{
writer.Add(vels[i], offset);
writer.Add(vel, offset);
offset += 3;
}
}

211
com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs


Vector3[] m_ModelSpaceLinearVelocities;
Vector3[] m_LocalSpaceLinearVelocities;
bool[] m_PoseEnabled;
/// Read access to the model space transforms.
/// Read iterator for the enabled model space transforms.
public IList<Pose> ModelSpacePoses
public IEnumerable<Pose> GetEnabledModelSpacePoses()
get { return m_ModelSpacePoses; }
if (m_ModelSpacePoses == null)
{
yield break;
}
for (var i = 0; i < m_ModelSpacePoses.Length; i++)
{
if (m_PoseEnabled[i])
{
yield return m_ModelSpacePoses[i];
}
}
/// Read access to the local space transforms.
/// Read iterator for the enabled local space transforms.
public IList<Pose> LocalSpacePoses
public IEnumerable<Pose> GetEnabledLocalSpacePoses()
get { return m_LocalSpacePoses; }
if (m_LocalSpacePoses == null)
{
yield break;
}
for (var i = 0; i < m_LocalSpacePoses.Length; i++)
{
if (m_PoseEnabled[i])
{
yield return m_LocalSpacePoses[i];
}
}
/// Read access to the model space linear velocities.
/// Read iterator for the enabled model space linear velocities.
public IList<Vector3> ModelSpaceVelocities
public IEnumerable<Vector3> GetEnabledModelSpaceVelocities()
get { return m_ModelSpaceLinearVelocities; }
if (m_ModelSpaceLinearVelocities == null)
{
yield break;
}
for (var i = 0; i < m_ModelSpaceLinearVelocities.Length; i++)
{
if (m_PoseEnabled[i])
{
yield return m_ModelSpaceLinearVelocities[i];
}
}
/// Read access to the local space linear velocities.
/// Read iterator for the enabled local space linear velocities.
/// </summary>
public IEnumerable<Vector3> GetEnabledLocalSpaceVelocities()
{
if (m_LocalSpaceLinearVelocities == null)
{
yield break;
}
for (var i = 0; i < m_LocalSpaceLinearVelocities.Length; i++)
{
if (m_PoseEnabled[i])
{
yield return m_LocalSpaceLinearVelocities[i];
}
}
}
/// <summary>
/// Number of enabled poses in the hierarchy (read-only).
public IList<Vector3> LocalSpaceVelocities
public int NumEnabledPoses
get { return m_LocalSpaceLinearVelocities; }
get
{
if (m_PoseEnabled == null)
{
return 0;
}
var numEnabled = 0;
for (var i = 0; i < m_PoseEnabled.Length; i++)
{
numEnabled += m_PoseEnabled[i] ? 1 : 0;
}
return numEnabled;
}
/// Number of poses in the hierarchy (read-only).
/// Number of total poses in the hierarchy (read-only).
get { return m_ModelSpacePoses?.Length ?? 0; }
get { return m_ModelSpacePoses?.Length ?? 0; }
}
/// <summary>

}
return m_ParentIndices[index];
}
/// <summary>
/// Set whether the pose at the given index is enabled or disabled for observations.
/// </summary>
/// <param name="index"></param>
/// <param name="val"></param>
public void SetPoseEnabled(int index, bool val)
{
m_PoseEnabled[index] = val;
}
/// <summary>

/// <param name="parentIndices"></param>
protected void SetParentIndices(int[] parentIndices)
protected void Setup(int[] parentIndices)
#if DEBUG
if (parentIndices[0] != -1)
{
throw new UnityAgentsException($"Expected parentIndices[0] to be -1, got {parentIndices[0]}");
}
#endif
var numTransforms = parentIndices.Length;
m_ModelSpacePoses = new Pose[numTransforms];
m_LocalSpacePoses = new Pose[numTransforms];
var numPoses = parentIndices.Length;
m_ModelSpacePoses = new Pose[numPoses];
m_LocalSpacePoses = new Pose[numPoses];
m_ModelSpaceLinearVelocities = new Vector3[numPoses];
m_LocalSpaceLinearVelocities = new Vector3[numPoses];
m_ModelSpaceLinearVelocities = new Vector3[numTransforms];
m_LocalSpaceLinearVelocities = new Vector3[numTransforms];
m_PoseEnabled = new bool[numPoses];
// All poses are enabled by default. Generally we'll want to disable the root though.
for (var i = 0; i < numPoses; i++)
{
m_PoseEnabled[i] = true;
}
}
/// <summary>

/// <returns></returns>
protected abstract Pose GetPoseAt(int index);
protected internal abstract Pose GetPoseAt(int index);
/// <summary>
/// Return the world space linear velocity of the i'th object.

protected abstract Vector3 GetLinearVelocityAt(int index);
protected internal abstract Vector3 GetLinearVelocityAt(int index);
/// <summary>

{
if (m_ModelSpacePoses == null)
using (TimerStack.Instance.Scoped("UpdateModelSpacePoses"))
return;
}
if (m_ModelSpacePoses == null)
{
return;
}
var rootWorldTransform = GetPoseAt(0);
var worldToModel = rootWorldTransform.Inverse();
var rootLinearVel = GetLinearVelocityAt(0);
var rootWorldTransform = GetPoseAt(0);
var worldToModel = rootWorldTransform.Inverse();
var rootLinearVel = GetLinearVelocityAt(0);
for (var i = 0; i < m_ModelSpacePoses.Length; i++)
{
var currentWorldSpacePose = GetPoseAt(i);
var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
m_ModelSpacePoses[i] = currentModelSpacePose;
for (var i = 0; i < m_ModelSpacePoses.Length; i++)
{
var currentWorldSpacePose = GetPoseAt(i);
var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
m_ModelSpacePoses[i] = currentModelSpacePose;
var currentBodyLinearVel = GetLinearVelocityAt(i);
var relativeVelocity = currentBodyLinearVel - rootLinearVel;
m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
var currentBodyLinearVel = GetLinearVelocityAt(i);
var relativeVelocity = currentBodyLinearVel - rootLinearVel;
m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
}
}
}

public void UpdateLocalSpacePoses()
{
if (m_LocalSpacePoses == null)
{
return;
}
for (var i = 0; i < m_LocalSpacePoses.Length; i++)
using (TimerStack.Instance.Scoped("UpdateLocalSpacePoses"))
if (m_ParentIndices[i] != -1)
if (m_LocalSpacePoses == null)
var parentTransform = GetPoseAt(m_ParentIndices[i]);
// This is slightly inefficient, since for a body with multiple children, we'll end up inverting
// the transform multiple times. Might be able to trade space for perf here.
var invParent = parentTransform.Inverse();
var currentTransform = GetPoseAt(i);
m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
var currentLinearVel = GetLinearVelocityAt(i);
m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
return;
else
for (var i = 0; i < m_LocalSpacePoses.Length; i++)
m_LocalSpacePoses[i] = Pose.identity;
m_LocalSpaceLinearVelocities[i] = Vector3.zero;
if (m_ParentIndices[i] != -1)
{
var parentTransform = GetPoseAt(m_ParentIndices[i]);
// This is slightly inefficient, since for a body with multiple children, we'll end up inverting
// the transform multiple times. Might be able to trade space for perf here.
var invParent = parentTransform.Inverse();
var currentTransform = GetPoseAt(i);
m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
var currentLinearVel = GetLinearVelocityAt(i);
m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
}
else
{
m_LocalSpacePoses[i] = Pose.identity;
m_LocalSpaceLinearVelocities[i] = Vector3.zero;
}
}
}
}

obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;
return NumPoses * obsPerPose;
return NumEnabledPoses * obsPerPose;
}
internal void DrawModelSpace(Vector3 offset)

77
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs


Rigidbody[] m_Bodies;
/// <summary>
/// Optional game object used to determine the root of the poses, separate from the actual Rigidbodies
/// in the hierarchy. For locomotion
/// </summary>
GameObject m_VirtualRoot;
/// <summary>
/// <param name="rootBody"></param>
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null)
/// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
/// separate from the actual Rigidbodies in the hierarchy. For locomotion tasks, with ragdolls, this provides
/// a stabilized refernece frame, which can improve learning.</param>
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null, GameObject virtualRoot = null)
{
if (rootBody == null)
{

{
rbs = rootGameObject.GetComponentsInChildren<Rigidbody>();
}
var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
var parentIndices = new int[rbs.Length];
if (rbs == null || rbs.Length == 0)
{
Debug.Log("No rigid bodies found!");
return;
}
if (rbs[0] != rootBody)
if (rbs[0] != rootBody)
// Adjust the array if we have a virtual root.
// This will be at index 0, and the "real" root will be parented to it.
if (virtualRoot != null)
{
var extendedRbs = new Rigidbody[rbs.Length + 1];
for (var i = 0; i < rbs.Length; i++)
{
extendedRbs[i + 1] = rbs[i];
}
rbs = extendedRbs;
}
var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
var parentIndices = new int[rbs.Length];
parentIndices[0] = -1;
bodyToIndex[rbs[i]] = i;
if(rbs[i] != null)
{
bodyToIndex[rbs[i]] = i;
}
}
var joints = rootBody.GetComponentsInChildren <Joint>();

parentIndices[childIndex] = parentIndex;
}
if (virtualRoot != null)
{
// Make sure the original root treats the virtual root as its parent.
parentIndices[1] = 0;
m_VirtualRoot = virtualRoot;
}
SetParentIndices(parentIndices);
Setup(parentIndices);
// By default, ignore the root
SetPoseEnabled(0, false);
protected override Vector3 GetLinearVelocityAt(int index)
protected internal override Vector3 GetLinearVelocityAt(int index)
if (index == 0 && m_VirtualRoot != null)
{
// No velocity on the virtual root
return Vector3.zero;
}
protected override Pose GetPoseAt(int index)
protected internal override Pose GetPoseAt(int index)
if (index == 0 && m_VirtualRoot != null)
{
// Use the GameObject's world transform
return new Pose
{
rotation = m_VirtualRoot.transform.rotation,
position = m_VirtualRoot.transform.position
};
}
var body = m_Bodies[index];
return new Pose { rotation = body.rotation, position = body.position };
}

9
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs


public Rigidbody RootBody;
/// <summary>
/// Optional GameObject used to determine the root of the poses.
/// </summary>
public GameObject VirtualRoot;
/// <summary>
/// Settings defining what types of observations will be generated.
/// </summary>
[SerializeField]

/// <returns></returns>
public override ISensor CreateSensor()
{
return new PhysicsBodySensor(RootBody, gameObject, Settings, sensorName);
return new PhysicsBodySensor(RootBody, gameObject, VirtualRoot, Settings, sensorName);
}
/// <inheritdoc/>

// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject);
var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot);
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs


// Local space
0f, 0f, 0f, // Root pos
13.37f, 0f, 0f, // Attached pos
4.2f, 0f, 0f, // Leaf pos
#endif
13.37f, 0f, 0f, // Attached pos
#if UNITY_2020_2_OR_NEWER
#endif
4.2f, 0f, 0f, // Leaf pos
#if UNITY_2020_2_OR_NEWER
0f, -1f, 1f // Leaf vel
#endif
};

88
com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs


{
class UselessPoseExtractor : PoseExtractor
{
protected override Pose GetPoseAt(int index)
protected internal override Pose GetPoseAt(int index)
protected override Vector3 GetLinearVelocityAt(int index)
protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}

SetParentIndices(parentIndices);
Setup(parentIndices);
}
}

{
parents[i] = i - 1;
}
SetParentIndices(parents);
Setup(parents);
protected override Pose GetPoseAt(int index)
protected internal override Pose GetPoseAt(int index)
{
var rotation = Quaternion.identity;
var translation = offset + new Vector3(index, index, index);

};
}
protected override Vector3 GetLinearVelocityAt(int index)
protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}

chain.UpdateModelSpacePoses();
chain.UpdateLocalSpacePoses();
// Root transforms are currently always the identity.
Assert.IsTrue(chain.ModelSpacePoses[0] == Pose.identity);
Assert.IsTrue(chain.LocalSpacePoses[0] == Pose.identity);
// Check the non-root transforms
for (var i = 1; i < size; i++)
var modelPoseIndex = 0;
foreach (var modelSpace in chain.GetEnabledModelSpacePoses())
var modelSpace = chain.ModelSpacePoses[i];
var expectedModelTranslation = new Vector3(i, i, i);
Assert.IsTrue(expectedModelTranslation == modelSpace.position);
if (modelPoseIndex == 0)
{
// Root transforms are currently always the identity.
Assert.IsTrue(modelSpace == Pose.identity);
}
else
{
var expectedModelTranslation = new Vector3(modelPoseIndex, modelPoseIndex, modelPoseIndex);
Assert.IsTrue(expectedModelTranslation == modelSpace.position);
var localSpace = chain.LocalSpacePoses[i];
var expectedLocalTranslation = new Vector3(1, 1, 1);
Assert.IsTrue(expectedLocalTranslation == localSpace.position);
}
modelPoseIndex++;
Assert.AreEqual(size, modelPoseIndex);
var localPoseIndex = 0;
foreach (var localSpace in chain.GetEnabledLocalSpacePoses())
{
if (localPoseIndex == 0)
{
// Root transforms are currently always the identity.
Assert.IsTrue(localSpace == Pose.identity);
}
else
{
var expectedLocalTranslation = new Vector3(1, 1, 1);
Assert.IsTrue(expectedLocalTranslation == localSpace.position, $"{expectedLocalTranslation} != {localSpace.position}");
}
localPoseIndex++;
}
Assert.AreEqual(size, localPoseIndex);
class BadPoseExtractor : PoseExtractor
{
public BadPoseExtractor()
{
var size = 2;
var parents = new int[size];
// Parents are intentionally invalid - expect -1 at root
for (var i = 0; i < size; i++)
{
parents[i] = i;
}
Setup(parents);
}
protected internal override Pose GetPoseAt(int index)
{
return Pose.identity;
}
protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
}
[Test]
public void TestExpectedRoot()
{
Assert.Throws<UnityAgentsException>(() =>
{
var bad = new BadPoseExtractor();
});
}
}
public class PoseExtensionTests

58
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs


using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
using UnityEditor;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{

var poseExtractor = new RigidBodyPoseExtractor(rb1);
Assert.AreEqual(2, poseExtractor.NumPoses);
rb1.position = new Vector3(1, 0, 0);
rb1.rotation = Quaternion.Euler(0, 13.37f, 0);
rb1.velocity = new Vector3(2, 0, 0);
Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(0).position);
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(0).rotation);
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(0));
}
[Test]
public void TestTwoBodiesVirtualRoot()
{
// * virtualRoot
// * rootObj
// - rb1
// * go2
// - rb2
// - joint
var virtualRoot = new GameObject("I am vroot");
var rootObj = new GameObject();
var rb1 = rootObj.AddComponent<Rigidbody>();
var go2 = new GameObject();
var rb2 = go2.AddComponent<Rigidbody>();
go2.transform.SetParent(rootObj.transform);
var joint = go2.AddComponent<ConfigurableJoint>();
joint.connectedBody = rb1;
var poseExtractor = new RigidBodyPoseExtractor(rb1, null, virtualRoot);
Assert.AreEqual(3, poseExtractor.NumPoses);
// "body" 0 has no parent
Assert.AreEqual(-1, poseExtractor.GetParentIndex(0));
// body 1 has parent 0
Assert.AreEqual(0, poseExtractor.GetParentIndex(1));
var virtualRootPos = new Vector3(0,2,0);
var virtualRootRot = Quaternion.Euler(0, 42, 0);
virtualRoot.transform.position = virtualRootPos;
virtualRoot.transform.rotation = virtualRootRot;
Assert.AreEqual(virtualRootPos, poseExtractor.GetPoseAt(0).position);
Assert.IsTrue(virtualRootRot == poseExtractor.GetPoseAt(0).rotation);
Assert.AreEqual(Vector3.zero, poseExtractor.GetLinearVelocityAt(0));
// Same as above test, but using index 1
rb1.position = new Vector3(1, 0, 0);
rb1.rotation = Quaternion.Euler(0, 13.37f, 0);
rb1.velocity = new Vector3(2, 0, 0);
Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(1).position);
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(1).rotation);
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(1));
}
}
}

25
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs


var sensor = sensorComponent.CreateSensor();
sensor.Update();
var expected = new[]
{
0f, 0f, 0f, // ModelSpaceLinearVelocity
0f, 0f, 0f, // LocalSpaceTranslations
0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
// The root body is ignored since it always generates identity values
// and there are no other bodies to generate observations.
var expected = new float[0];
SensorTestHelper.CompareObservation(sensor, expected);
}
[Test]

var joint2 = leafGameObj.AddComponent<ConfigurableJoint>();
joint2.connectedBody = middleRb;
var virtualRoot = new GameObject();
var sensorComponent = rootObj.AddComponent<RigidBodySensorComponent>();
sensorComponent.RootBody = rootRb;

UseLocalSpaceTranslations = true,
UseLocalSpaceLinearVelocity = true
};
sensorComponent.VirtualRoot = virtualRoot;
// Note that the VirtualRoot is ignored from the observations
var expected = new[]
{
// Model space

// Local space
0f, 0f, 0f, // Root pos
0f, 0f, 0f, // Root vel
-1f, 1f, 0f, // Attached vel
4.2f, 0f, 0f, // Leaf pos
4.2f, 0f, 0f, // Leaf pos
1f, 0f, 0f, // Root vel (relative to virtual root)
-1f, 1f, 0f, // Attached vel
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
// Update the settings to only process joint observations
sensorComponent.Settings = new PhysicsSensorSettings

1
com.unity.ml-agents/Runtime/AssemblyInfo.cs


[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions")]

3
com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs


using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]

11
com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta


fileFormatVersion: 2
guid: 48c8790647c3345e19c57d6c21065112
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存