using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
    /// The observation is a flat vector: the pose observations produced by the pose extractor,
    /// followed by the observations of one joint extractor per enabled body.
    /// </summary>
    public class PhysicsBodySensor : ISensor
    {
        readonly int[] m_Shape;
        readonly string m_SensorName;

        readonly PoseExtractor m_PoseExtractor;
        // One joint extractor per enabled body, in the order the pose extractor enumerates them.
        readonly List<IJointExtractor> m_JointExtractors;
        PhysicsSensorSettings m_Settings;

        /// <summary>
        /// Construct a new PhysicsBodySensor that observes the enabled Rigidbodies of a pose extractor.
        /// </summary>
        /// <param name="poseExtractor">
        /// Pose extractor for the Rigidbody hierarchy. A joint extractor is created for each of its
        /// enabled Rigidbodies.
        /// </param>
        /// <param name="settings">
        /// Settings passed to the pose and joint extractors when counting and writing observations;
        /// also selects which pose spaces are refreshed in <see cref="Update"/>.
        /// </param>
        /// <param name="sensorName">Name returned by <see cref="GetName"/>; stored as given.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="poseExtractor"/> is null.</exception>
        public PhysicsBodySensor(
            RigidBodyPoseExtractor poseExtractor,
            PhysicsSensorSettings settings,
            string sensorName
        )
        {
            if (poseExtractor == null)
            {
                throw new System.ArgumentNullException(nameof(poseExtractor));
            }

            m_PoseExtractor = poseExtractor;
            m_SensorName = sensorName;
            m_Settings = settings;

            m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
            foreach (var rb in poseExtractor.GetEnabledRigidbodies())
            {
                m_JointExtractors.Add(new RigidBodyJointExtractor(rb));
            }

            m_Shape = ComputeObservationShape();
        }

#if UNITY_2020_1_OR_NEWER
        /// <summary>
        /// Construct a new PhysicsBodySensor that observes the enabled ArticulationBodies under a root body.
        /// </summary>
        /// <param name="rootBody">Root of the ArticulationBody hierarchy to observe.</param>
        /// <param name="settings">
        /// Settings passed to the pose and joint extractors when counting and writing observations;
        /// also selects which pose spaces are refreshed in <see cref="Update"/>.
        /// </param>
        /// <param name="sensorName">
        /// Name returned by <see cref="GetName"/>. If null or empty, defaults to
        /// "ArticulationBodySensor:" followed by the root body's name.
        /// </param>
        public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
        {
            var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
            m_PoseExtractor = poseExtractor;

            // Use Unity's overloaded != rather than ?. so a destroyed body is treated as null
            // instead of throwing when its name is read.
            var rootName = rootBody != null ? rootBody.name : null;
            m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootName}" : sensorName;
            m_Settings = settings;

            m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
            foreach (var articBody in poseExtractor.GetEnabledArticulationBodies())
            {
                m_JointExtractors.Add(new ArticulationBodyJointExtractor(articBody));
            }

            m_Shape = ComputeObservationShape();
        }
#endif

        /// <summary>
        /// Computes the one-dimensional observation shape: the pose observation count plus the
        /// observation counts of every joint extractor. Requires m_PoseExtractor, m_Settings and
        /// m_JointExtractors to be assigned.
        /// </summary>
        int[] ComputeObservationShape()
        {
            var numJointExtractorObservations = 0;
            foreach (var jointExtractor in m_JointExtractors)
            {
                numJointExtractorObservations += jointExtractor.NumObservations(m_Settings);
            }

            var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(m_Settings);
            return new[] { numTransformObservations + numJointExtractorObservations };
        }

        /// <inheritdoc/>
        public int[] GetObservationShape()
        {
            return m_Shape;
        }

        /// <inheritdoc/>
        public int Write(ObservationWriter writer)
        {
            var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
            foreach (var jointExtractor in m_JointExtractors)
            {
                // The running total is passed as the third argument; presumably it is the offset
                // at which this extractor starts writing (i.e. just past the previous data).
                numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
            }
            return numWritten;
        }

        /// <inheritdoc/>
        public byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Refreshes model-space poses when the settings use model space, and local-space poses
        /// when the settings use local space.
        /// </remarks>
        public void Update()
        {
            if (m_Settings.UseModelSpace)
            {
                m_PoseExtractor.UpdateModelSpacePoses();
            }
            if (m_Settings.UseLocalSpace)
            {
                m_PoseExtractor.UpdateLocalSpacePoses();
            }
        }

        /// <inheritdoc/>
        public void Reset() {}

        /// <inheritdoc/>
        public SensorCompressionType GetCompressionType()
        {
            return SensorCompressionType.None;
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_SensorName;
        }
    }
}