using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
    /// </summary>
    public class PhysicsBodySensor : ISensor
    {
        int[] m_Shape;
        string m_SensorName;

        PoseExtractor m_PoseExtractor;
        IJointExtractor[] m_JointExtractors;
        PhysicsSensorSettings m_Settings;

        /// <summary>
        /// Construct a new PhysicsBodySensor for a hierarchy of Rigidbodies.
        /// </summary>
        /// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
        /// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
        /// <param name="virtualRoot">Optional GameObject used to determine the root of the poses.</param>
        /// <param name="settings">
        /// Settings passed to the pose and joint extractors; they decide which observations are produced
        /// and which pose spaces are refreshed in <see cref="Update"/>.
        /// </param>
        /// <param name="sensorName">
        /// Optional name for the sensor. If null or empty, a name is derived from <paramref name="rootBody"/>.
        /// </param>
        public PhysicsBodySensor(
            Rigidbody rootBody,
            GameObject rootGameObject,
            GameObject virtualRoot,
            PhysicsSensorSettings settings,
            string sensorName = null
        )
        {
            var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject, virtualRoot);
            m_PoseExtractor = poseExtractor;

            // Use Unity's overloaded != instead of ?. so that a destroyed (but not C#-null) body is
            // treated as null rather than throwing when .name is accessed.
            var rootName = rootBody != null ? rootBody.name : null;
            m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootName}" : sensorName;
            m_Settings = settings;

            var rigidBodies = poseExtractor.Bodies;
            m_JointExtractors = AllocateJointExtractors(rigidBodies?.Length ?? 0);
            var numJointExtractorObservations = 0;
            for (var i = 0; i < m_JointExtractors.Length; i++)
            {
                // Extractor i belongs to body i + 1; the root body (index 0) is skipped.
                var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i + 1]);
                numJointExtractorObservations += jointExtractor.NumObservations(settings);
                m_JointExtractors[i] = jointExtractor;
            }

            m_Shape = ComputeShape(numJointExtractorObservations);
        }

#if UNITY_2020_1_OR_NEWER
        /// <summary>
        /// Construct a new PhysicsBodySensor for an ArticulationBody hierarchy.
        /// </summary>
        /// <param name="rootBody">The root ArticulationBody of the hierarchy.</param>
        /// <param name="settings">
        /// Settings passed to the pose and joint extractors; they decide which observations are produced
        /// and which pose spaces are refreshed in <see cref="Update"/>.
        /// </param>
        /// <param name="sensorName">
        /// Optional name for the sensor. If null or empty, a name is derived from <paramref name="rootBody"/>.
        /// </param>
        public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
        {
            var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
            m_PoseExtractor = poseExtractor;

            // Use Unity's overloaded != instead of ?. so that a destroyed (but not C#-null) body is
            // treated as null rather than throwing when .name is accessed.
            var rootName = rootBody != null ? rootBody.name : null;
            m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootName}" : sensorName;
            m_Settings = settings;

            var articBodies = poseExtractor.Bodies;
            m_JointExtractors = AllocateJointExtractors(articBodies?.Length ?? 0);
            var numJointExtractorObservations = 0;
            for (var i = 0; i < m_JointExtractors.Length; i++)
            {
                // Extractor i belongs to body i + 1; the root body (index 0) is skipped.
                var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i + 1]);
                numJointExtractorObservations += jointExtractor.NumObservations(settings);
                m_JointExtractors[i] = jointExtractor;
            }

            m_Shape = ComputeShape(numJointExtractorObservations);
        }
#endif

        /// <summary>
        /// Allocate one joint-extractor slot for every non-root body.
        /// </summary>
        /// <param name="numBodies">Number of bodies reported by the pose extractor, including the root.</param>
        /// <returns>
        /// An array of length <c>numBodies - 1</c>, or an empty array when there are no non-root bodies.
        /// </returns>
        static IJointExtractor[] AllocateJointExtractors(int numBodies)
        {
            // The root body gets no extractor. Clamp at zero so that an empty body array
            // cannot produce a negative array length.
            return new IJointExtractor[numBodies > 1 ? numBodies - 1 : 0];
        }

        /// <summary>
        /// Build the one-dimensional observation shape. Its size is the number of pose observations
        /// plus the given number of joint observations.
        /// </summary>
        /// <param name="numJointExtractorObservations">Total observations produced by all joint extractors.</param>
        /// <returns>A single-element shape array.</returns>
        int[] ComputeShape(int numJointExtractorObservations)
        {
            var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(m_Settings);
            return new[] { numTransformObservations + numJointExtractorObservations };
        }

        /// <inheritdoc/>
        public int[] GetObservationShape()
        {
            return m_Shape;
        }

        /// <inheritdoc/>
        public int Write(ObservationWriter writer)
        {
            // Poses are written first; each joint extractor then writes at the running offset.
            var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
            foreach (var jointExtractor in m_JointExtractors)
            {
                numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
            }
            return numWritten;
        }

        /// <inheritdoc/>
        public byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <inheritdoc/>
        public void Update()
        {
            // Only refresh the pose spaces that the settings request.
            if (m_Settings.UseModelSpace)
            {
                m_PoseExtractor.UpdateModelSpacePoses();
            }
            if (m_Settings.UseLocalSpace)
            {
                m_PoseExtractor.UpdateLocalSpacePoses();
            }
        }

        /// <inheritdoc/>
        public void Reset() {}

        /// <inheritdoc/>
        public SensorCompressionType GetCompressionType()
        {
            return SensorCompressionType.None;
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_SensorName;
        }
    }
}