using System.Collections.Generic;
#if UNITY_2020_1_OR_NEWER
using UnityEngine;
#endif
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
    /// </summary>
    /// <remarks>
    /// The observation is a single flat vector: the pose observations written by the
    /// pose extractor, followed by the observations of one joint extractor per enabled
    /// body, in the order the bodies were enumerated at construction time.
    /// </remarks>
    public class PhysicsBodySensor : ISensor, IBuiltInSensor
    {
        readonly ObservationSpec m_ObservationSpec;
        readonly string m_SensorName;

        readonly PoseExtractor m_PoseExtractor;
        readonly List<IJointExtractor> m_JointExtractors;
        readonly PhysicsSensorSettings m_Settings;

        /// <summary>
        /// Construct a new PhysicsBodySensor for a hierarchy of Rigidbodies.
        /// </summary>
        /// <param name="poseExtractor">
        /// Pose extractor for the Rigidbody hierarchy. A joint extractor is created for each
        /// Rigidbody returned by its <c>GetEnabledRigidbodies()</c>.
        /// </param>
        /// <param name="settings">Settings that control which observations are generated.</param>
        /// <param name="sensorName">Name of the sensor, returned as-is by <see cref="GetName"/>.</param>
        public PhysicsBodySensor(
            RigidBodyPoseExtractor poseExtractor,
            PhysicsSensorSettings settings,
            string sensorName
        )
        {
            m_PoseExtractor = poseExtractor;
            m_SensorName = sensorName;
            m_Settings = settings;

            var numJointExtractorObservations = 0;
            // Presize to the number of enabled poses; one joint extractor is added per enabled body.
            m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
            foreach (var rb in poseExtractor.GetEnabledRigidbodies())
            {
                var jointExtractor = new RigidBodyJointExtractor(rb);
                numJointExtractorObservations += jointExtractor.NumObservations(settings);
                m_JointExtractors.Add(jointExtractor);
            }

            var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
            m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
        }

#if UNITY_2020_1_OR_NEWER
        /// <summary>
        /// Construct a new PhysicsBodySensor for a hierarchy of ArticulationBodies rooted at <paramref name="rootBody"/>.
        /// </summary>
        /// <param name="rootBody">Root of the ArticulationBody hierarchy to observe.</param>
        /// <param name="settings">Settings that control which observations are generated.</param>
        /// <param name="sensorName">
        /// Name of the sensor. If null or empty, defaults to <c>"ArticulationBodySensor:"</c>
        /// followed by the root body's name.
        /// </param>
        public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
        {
            var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
            m_PoseExtractor = poseExtractor;
            // rootBody?.name keeps default-name generation from throwing when rootBody is null.
            m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName;
            m_Settings = settings;

            var numJointExtractorObservations = 0;
            m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
            foreach (var articBody in poseExtractor.GetEnabledArticulationBodies())
            {
                var jointExtractor = new ArticulationBodyJointExtractor(articBody);
                numJointExtractorObservations += jointExtractor.NumObservations(settings);
                m_JointExtractors.Add(jointExtractor);
            }

            var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
            m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
        }

#endif

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Writes the pose observations first, then each joint extractor's observations.
        /// Each joint extractor is given the running count of values written so far as its offset.
        /// </remarks>
        public int Write(ObservationWriter writer)
        {
            var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
            foreach (var jointExtractor in m_JointExtractors)
            {
                numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
            }
            return numWritten;
        }

        /// <inheritdoc/>
        public byte[] GetCompressedObservation()
        {
            // Observations are uncompressed (see GetCompressionSpec), so there is no compressed form.
            return null;
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Refreshes the model-space and/or local-space poses, depending on which spaces
        /// are enabled in the settings.
        /// </remarks>
        public void Update()
        {
            if (m_Settings.UseModelSpace)
            {
                m_PoseExtractor.UpdateModelSpacePoses();
            }

            if (m_Settings.UseLocalSpace)
            {
                m_PoseExtractor.UpdateLocalSpacePoses();
            }
        }

        /// <inheritdoc/>
        public void Reset() { }

        /// <inheritdoc/>
        public CompressionSpec GetCompressionSpec()
        {
            return CompressionSpec.Default();
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_SensorName;
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.PhysicsBodySensor;
        }
    }
}