using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Editor component that creates a PhysicsBodySensor for the Agent.
    /// </summary>
    public class RigidBodySensorComponent : SensorComponent
    {
        /// <summary>
        /// The root Rigidbody of the system.
        /// </summary>
        public Rigidbody RootBody;

        /// <summary>
        /// Optional GameObject used to determine the root of the poses.
        /// </summary>
        public GameObject VirtualRoot;

        /// <summary>
        /// Settings defining what types of observations will be generated.
        /// </summary>
        [SerializeField]
        public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();

        /// <summary>
        /// Optional sensor name. This must be unique for each Agent.
        /// </summary>
        public string sensorName;

        /// <summary>
        /// Creates a PhysicsBodySensor from this component's
        /// <see cref="RootBody"/>, GameObject, <see cref="VirtualRoot"/>,
        /// <see cref="Settings"/> and <see cref="sensorName"/>.
        /// </summary>
        /// <returns>The newly created sensor.</returns>
        public override ISensor CreateSensor()
        {
            return new PhysicsBodySensor(RootBody, gameObject, VirtualRoot, Settings, sensorName);
        }

        /// <summary>
        /// Computes the shape of the observations this component's sensor will produce.
        /// </summary>
        /// <returns>
        /// A one-element array holding the number of pose observations plus the
        /// number of joint observations of every non-root body, or <c>{ 0 }</c>
        /// when <see cref="RootBody"/> is not set.
        /// </returns>
        public override int[] GetObservationShape()
        {
            if (RootBody == null)
            {
                return new[] { 0 };
            }

            // TODO static method in PhysicsBodySensor?
            // TODO only update PoseExtractor when body changes?
            var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot);
            var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);

            var numJointObservations = 0;
            var bodies = poseExtractor.Bodies;
            // Start from i=1 to ignore the root
            for (var i = 1; i < bodies.Length; i++)
            {
                var body = bodies[i];
                // Use an explicit comparison instead of `?.`: the null-conditional
                // operator bypasses UnityEngine.Object's overloaded null check, so a
                // destroyed Rigidbody would still be dereferenced.
                var joint = body != null ? body.GetComponent<Joint>() : null;
                numJointObservations += RigidBodyJointExtractor.NumObservations(body, joint, Settings);
            }

            return new[] { numPoseObservations + numJointObservations };
        }
    }
}