浏览代码

[MLA-1138] joint observations (#4224)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
4f4f2445
共有 15 个文件被更改,包括 424 次插入31 次删除
  1. 2
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
  2. 12
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
  3. 53
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  4. 30
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
  5. 18
      com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
  6. 2
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
  7. 13
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
  8. 34
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
  9. 22
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
  10. 147
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs
  11. 11
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta
  12. 27
      com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs
  13. 11
      com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta
  14. 62
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs
  15. 11
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta

2
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs


var t = go.transform;
return new Pose { rotation = t.rotation, position = t.position };
}
internal ArticulationBody[] Bodies => m_Bodies;
}
}
#endif // UNITY_2020_1_OR_NEWER

12
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs


// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
return new[] { numTransformObservations };
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
// Start from i=1 to ignore the root
for (var i = 1; i < poseExtractor.Bodies.Length; i++)
{
numJointObservations += ArticulationBodyJointExtractor.NumObservations(
poseExtractor.Bodies[i], Settings
);
}
return new[] { numPoseObservations + numJointObservations };
}
}

53
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


string m_SensorName;
PoseExtractor m_PoseExtractor;
IJointExtractor[] m_JointExtractors;
PhysicsSensorSettings m_Settings;
/// <summary>

/// <param name="sensorName"></param>
public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
{
m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
m_PoseExtractor = poseExtractor;
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
var numJointExtractorObservations = 0;
var rigidBodies = poseExtractor.Bodies;
if (rigidBodies != null)
{
m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root
for (var i = 1; i < rigidBodies.Length; i++)
{
var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
{
m_JointExtractors = new IJointExtractor[0];
}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);
var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
m_PoseExtractor = poseExtractor;
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
var numJointExtractorObservations = 0;
var articBodies = poseExtractor.Bodies;
if (articBodies != null)
{
m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root
for (var i = 1; i < articBodies.Length; i++)
{
var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
{
m_JointExtractors = new IJointExtractor[0];
}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_Shape = new[] { numTransformObservations + numJointExtractorObservations };
}
#endif

public int Write(ObservationWriter writer)
{
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
foreach (var jointExtractor in m_JointExtractors)
{
numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
}
return numWritten;
}

30
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs


public bool UseLocalSpaceLinearVelocity;
/// <summary>
/// Whether to use joint-specific positions and angles as observations.
/// </summary>
public bool UseJointPositionsAndAngles;
/// <summary>
/// Whether to use the joint forces and torques that are applied by the solver as observations.
/// </summary>
public bool UseJointForces;
/// <summary>
/// Creates a PhysicsSensorSettings with reasonable default values.
/// </summary>
/// <returns></returns>

public bool UseLocalSpace
{
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
}
/// <summary>
/// The number of floats needed to represent a given number of transforms.
/// </summary>
/// <param name="numTransforms"></param>
/// <returns></returns>
public int TransformSize(int numTransforms)
{
int obsPerTransform = 0;
obsPerTransform += UseModelSpaceTranslations ? 3 : 0;
obsPerTransform += UseModelSpaceRotations ? 4 : 0;
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;
obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0;
obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0;
return numTransforms * obsPerTransform;
}
}

18
com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs


}
}
/// <summary>
/// Compute the number of floats needed to represent the poses for the given PhysicsSensorSettings.
/// </summary>
/// <param name="settings"></param>
/// <returns></returns>
public int GetNumPoseObservations(PhysicsSensorSettings settings)
{
int obsPerPose = 0;
obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0;
obsPerPose += settings.UseModelSpaceRotations ? 4 : 0;
obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0;
obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0;
obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;
return NumPoses * obsPerPose;
}
internal void DrawModelSpace(Vector3 offset)
{

2
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs


var body = m_Bodies[index];
return new Pose { rotation = body.rotation, position = body.position };
}
internal Rigidbody[] Bodies => m_Bodies;
}
}

13
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs


// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject);
var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
return new[] { numTransformObservations };
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
// Start from i=1 to ignore the root
for (var i = 1; i < poseExtractor.Bodies.Length; i++)
{
var body = poseExtractor.Bodies[i];
var joint = body?.GetComponent<Joint>();
numJointObservations += RigidBodyJointExtractor.NumObservations(body, joint, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}
}

34
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs


0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
}
[Test]

var leafArticBody = leafGameObj.AddComponent<ArticulationBody>();
leafGameObj.transform.SetParent(middleGamObj.transform);
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
leafArticBody.jointType = ArticulationJointType.RevoluteJoint;
leafArticBody.jointType = ArticulationJointType.PrismaticJoint;
leafArticBody.linearLockZ = ArticulationDofLock.LimitedMotion;
leafArticBody.zDrive = new ArticulationDrive
{
lowerLimit = -3,
upperLimit = 1
};
#if UNITY_2020_2_OR_NEWER
// ArticulationBody.velocity is read-only in 2020.1

#endif
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
// Update the settings to only process joint observations
sensorComponent.Settings = new PhysicsSensorSettings
{
UseJointForces = true,
UseJointPositionsAndAngles = true,
};
sensor = sensorComponent.CreateSensor();
sensor.Update();
expected = new[]
{
// revolute
0f, 1f, // joint1.position (sin and cos)
0f, // joint1.force
// prismatic
0.5f, // joint2.position (interpolate between limits)
0f, // joint2.force
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
}
}
}

22
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs


0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
}
[Test]

0f, -1f, 1f // Leaf vel
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
// Update the settings to only process joint observations
sensorComponent.Settings = new PhysicsSensorSettings
{
UseJointPositionsAndAngles = true,
UseJointForces = true,
};
sensor = sensorComponent.CreateSensor();
sensor.Update();
expected = new[]
{
0f, 0f, 0f, // joint1.force
0f, 0f, 0f, // joint1.torque
0f, 0f, 0f, // joint2.force
0f, 0f, 0f, // joint2.torque
};
SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
}
}

147
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs


#if UNITY_2020_1_OR_NEWER
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
public class ArticulationBodyJointExtractor : IJointExtractor
{
ArticulationBody m_Body;
public ArticulationBodyJointExtractor(ArticulationBody body)
{
m_Body = body;
}
public int NumObservations(PhysicsSensorSettings settings)
{
return NumObservations(m_Body, settings);
}
public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings)
{
if (body == null || body.isRoot)
{
return 0;
}
var totalCount = 0;
if (settings.UseJointPositionsAndAngles)
{
switch (body.jointType)
{
case ArticulationJointType.RevoluteJoint:
case ArticulationJointType.SphericalJoint:
// Both RevoluteJoint and SphericalJoint have all angular components.
// We use sine and cosine of the angles for the observations.
totalCount += 2 * body.dofCount;
break;
case ArticulationJointType.FixedJoint:
// Since FixedJoint can't moved, there aren't any interesting observations for it.
break;
case ArticulationJointType.PrismaticJoint:
// One linear component
totalCount += body.dofCount;
break;
}
}
if (settings.UseJointForces)
{
totalCount += body.dofCount;
}
return totalCount;
}
public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
{
if (m_Body == null || m_Body.isRoot)
{
return 0;
}
var currentOffset = offset;
// Write joint positions
if (settings.UseJointPositionsAndAngles)
{
switch (m_Body.jointType)
{
case ArticulationJointType.RevoluteJoint:
case ArticulationJointType.SphericalJoint:
// All joint positions are angular
for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
{
var jointRotationRads = m_Body.jointPosition[dofIndex];
writer[currentOffset++] = Mathf.Sin(jointRotationRads);
writer[currentOffset++] = Mathf.Cos(jointRotationRads);
}
break;
case ArticulationJointType.FixedJoint:
// No observations
break;
case ArticulationJointType.PrismaticJoint:
writer[currentOffset++] = GetPrismaticValue();
break;
}
}
if (settings.UseJointForces)
{
for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
{
// take tanh to keep in [-1, 1]
writer[currentOffset++] = (float) System.Math.Tanh(m_Body.jointForce[dofIndex]);
}
}
return currentOffset - offset;
}
float GetPrismaticValue()
{
// Prismatic joints should have at most one free axis.
bool limited = false;
var drive = m_Body.xDrive;
if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion)
{
drive = m_Body.xDrive;
limited = true;
}
else if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion)
{
drive = m_Body.yDrive;
limited = true;
}
else if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion)
{
drive = m_Body.zDrive;
limited = true;
}
var jointPos = m_Body.jointPosition[0];
if (limited)
{
// If locked, interpolate between the limits.
var upperLimit = drive.upperLimit;
var lowerLimit = drive.lowerLimit;
if (upperLimit <= lowerLimit)
{
// Invalid limits (probably equal), so don't try to lerp
return 0;
}
var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos);
// Convert [0, 1] -> [-1, 1]
var normalized = 2.0f * invLerped - 1.0f;
return normalized;
}
// take tanh() to keep in [-1, 1]
return (float) System.Math.Tanh(jointPos);
}
}
}
#endif

11
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta


fileFormatVersion: 2
guid: 238d15f867b9c4ced9cef331b7420b27
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

27
com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs


using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Interface for generating observations from a physical joint or constraint.
/// </summary>
public interface IJointExtractor
{
/// <summary>
/// Determine the number of observations that would be generated for the particular joint
/// using the provided PhysicsSensorSettings.
/// </summary>
/// <param name="settings"></param>
/// <returns>Number of floats that will be written.</returns>
int NumObservations(PhysicsSensorSettings settings);
/// <summary>
/// Write the observations to the ObservationWriter, starting at the specified offset.
/// </summary>
/// <param name="settings"></param>
/// <param name="writer"></param>
/// <param name="offset"></param>
/// <returns>Number of floats that were written.</returns>
int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset);
}
}

11
com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta


fileFormatVersion: 2
guid: 2d2a01ea194334a4682d5c8cad4a956b
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

62
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs


using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
public class RigidBodyJointExtractor : IJointExtractor
{
Rigidbody m_Body;
Joint m_Joint;
public RigidBodyJointExtractor(Rigidbody body)
{
m_Body = body;
m_Joint = m_Body?.GetComponent<Joint>();
}
public int NumObservations(PhysicsSensorSettings settings)
{
return NumObservations(m_Body, m_Joint, settings);
}
public static int NumObservations(Rigidbody body, Joint joint, PhysicsSensorSettings settings)
{
if(body == null || joint == null)
{
return 0;
}
var numObservations = 0;
if (settings.UseJointForces)
{
// 3 force and 3 torque values
numObservations += 6;
}
return numObservations;
}
public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
{
if (m_Body == null || m_Joint == null)
{
return 0;
}
var currentOffset = offset;
if (settings.UseJointForces)
{
// Take tanh of the forces and torques to ensure they're in [-1, 1]
writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.x);
writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.y);
writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.z);
writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.x);
writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.y);
writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.z);
}
return currentOffset - offset;
}
}
}

11
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta


fileFormatVersion: 2
guid: 5014d7ab95c6a44469f447b8a7019746
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存