using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Editor component that creates a PhysicsBodySensor for the Agent.
    /// </summary>
    public class RigidBodySensorComponent : SensorComponent
    {
        /// <summary>
        /// The root Rigidbody of the system.
        /// </summary>
        public Rigidbody RootBody;

        /// <summary>
        /// Optional GameObject used to determine the root of the poses.
        /// </summary>
        public GameObject VirtualRoot;

        /// <summary>
        /// Settings defining what types of observations will be generated.
        /// </summary>
        [SerializeField]
        public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();

        /// <summary>
        /// Optional sensor name. This must be unique for each Agent.
        /// If left empty, "PhysicsBodySensor:" followed by the name of <see cref="RootBody"/> is used.
        /// </summary>
        [SerializeField]
        public string sensorName;

        // Cached pose extractor: created lazily by GetPoseExtractor() and rebuilt by ResetPoseExtractor().
        [SerializeField]
        [HideInInspector]
        RigidBodyPoseExtractor m_PoseExtractor;

        /// <summary>
        /// Creates a PhysicsBodySensor.
        /// </summary>
        /// <returns>
        /// A new PhysicsBodySensor built from this component's pose extractor, <see cref="Settings"/>,
        /// and either <see cref="sensorName"/> or a name derived from <see cref="RootBody"/>.
        /// </returns>
        public override ISensor CreateSensor()
        {
            // Use UnityEngine.Object's overloaded != rather than `?.`: the null-conditional operator
            // bypasses Unity's destroyed/unassigned-object check, so `RootBody?.name` could throw
            // instead of falling back to an empty root name.
            var rootName = RootBody != null ? RootBody.name : string.Empty;
            var resolvedName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootName}" : sensorName;
            return new PhysicsBodySensor(GetPoseExtractor(), Settings, resolvedName);
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Returns a single-element shape: the number of pose observations plus the joint
        /// observations of every enabled rigidbody. Returns { 0 } when no <see cref="RootBody"/> is set.
        /// </remarks>
        public override int[] GetObservationShape()
        {
            if (RootBody == null)
            {
                return new[] { 0 };
            }

            var poseExtractor = GetPoseExtractor();
            var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);

            var numJointObservations = 0;
            foreach (var rb in poseExtractor.GetEnabledRigidbodies())
            {
                // GetComponent returns null when the body has no Joint; NumObservations receives it as-is.
                var joint = rb.GetComponent<Joint>();
                numJointObservations += RigidBodyJointExtractor.NumObservations(rb, joint, Settings);
            }

            return new[] { numPoseObservations + numJointObservations };
        }

        /// <summary>
        /// Get the DisplayNodes of the hierarchy.
        /// </summary>
        /// <returns>The display nodes reported by the (lazily created) pose extractor.</returns>
        internal IList<PoseExtractor.DisplayNode> GetDisplayNodes()
        {
            return GetPoseExtractor().GetDisplayNodes();
        }

        /// <summary>
        /// Lazy construction of the PoseExtractor.
        /// </summary>
        /// <returns>The cached pose extractor, creating it via <see cref="ResetPoseExtractor"/> if absent.</returns>
        RigidBodyPoseExtractor GetPoseExtractor()
        {
            if (m_PoseExtractor == null)
            {
                ResetPoseExtractor();
            }

            return m_PoseExtractor;
        }

        /// <summary>
        /// Reset the pose extractor, trying to keep the enabled state of the corresponding poses the same.
        /// </summary>
        internal void ResetPoseExtractor()
        {
            // Capture the current per-body enabled state (when an extractor already exists) so the new
            // extractor can be reinitialized with it; null is passed when there is no previous state.
            var bodyPosesEnabled = m_PoseExtractor != null ? m_PoseExtractor.GetBodyPosesEnabled() : null;
            m_PoseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot, bodyPosesEnabled);
        }

        /// <summary>
        /// Toggle the pose at the given index.
        /// </summary>
        /// <param name="index">Index of the pose within the pose extractor.</param>
        /// <param name="enabled">Whether the pose should be enabled.</param>
        internal void SetPoseEnabled(int index, bool enabled)
        {
            GetPoseExtractor().SetPoseEnabled(index, enabled);
        }
    }
}