using Unity.Barracuda;
using System.Collections.Generic;
using System.Diagnostics;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Policies
{
    /// <summary>
    /// Where to perform inference.
    /// </summary>
    public enum InferenceDevice
    {
        /// <summary>
        /// CPU inference. Corresponds to WorkerFactory.Type.CSharp in Barracuda.
        /// Burst is recommended instead; this is kept for legacy compatibility.
        /// </summary>
        CPU = 0,

        /// <summary>
        /// GPU inference. Corresponds to WorkerFactory.Type.ComputePrecompiled in Barracuda.
        /// </summary>
        GPU = 1,

        /// <summary>
        /// CPU inference using Burst. Corresponds to WorkerFactory.Type.CSharpBurst in Barracuda.
        /// </summary>
        Burst = 2,
    }

    /// <summary>
    /// The Barracuda Policy uses a Barracuda Model to make decisions at
    /// every step. It uses a ModelRunner that is shared across all
    /// Barracuda Policies that use the same model and inference devices.
    /// </summary>
    internal class BarracudaPolicy : IPolicy
    {
        /// <summary>
        /// Runner obtained from the Academy in the constructor. Every member of this
        /// class treats it as possibly null.
        /// </summary>
        protected ModelRunner m_ModelRunner;

        /// <summary>
        /// Backing storage for the reference returned by <see cref="DecideAction"/>.
        /// </summary>
        ActionBuffers m_LastActionBuffer;

        /// <summary>
        /// Episode id recorded by the most recent <see cref="RequestDecision"/> call;
        /// used as the lookup key when fetching the action from the runner.
        /// </summary>
        int m_AgentId;

        /// <summary>
        /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
        /// </summary>
        // NOTE(review): element type assumed to be int[]; the field is not used in this file - confirm.
        List<int[]> m_SensorShapes;

        ActionSpec m_ActionSpec;

        private string m_BehaviorName;

        /// <summary>
        /// List of actuators, only used for analytics
        /// </summary>
        private IList<IActuator> m_Actuators;

        /// <summary>
        /// Whether or not we've tried to send analytics for this model. We only ever try to send once per policy,
        /// and do additional deduplication in the analytics code.
        /// </summary>
        private bool m_AnalyticsSent;

        /// <summary>
        /// Instantiate a BarracudaPolicy with the necessary objects for it to run.
        /// </summary>
        /// <param name="actionSpec">The action spec of the behavior.</param>
        /// <param name="actuators">The actuators used for this behavior.</param>
        /// <param name="model">The Neural Network to use.</param>
        /// <param name="inferenceDevice">Which device Barracuda will run on.</param>
        /// <param name="behaviorName">The name of the behavior.</param>
        public BarracudaPolicy(
            ActionSpec actionSpec,
            IList<IActuator> actuators,
            NNModel model,
            InferenceDevice inferenceDevice,
            string behaviorName
        )
        {
            // The Academy hands out the runner for this (model, actionSpec, device) combination.
            m_ModelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice);
            m_BehaviorName = behaviorName;
            m_ActionSpec = actionSpec;
            m_Actuators = actuators;
        }

        /// <inheritdoc />
        public void RequestDecision(AgentInfo info, List<ISensor> sensors)
        {
            SendAnalytics(sensors);

            // Remember which agent asked so DecideAction can fetch the matching action.
            m_AgentId = info.episodeId;
            m_ModelRunner?.PutObservations(info, sensors);
        }

        /// <summary>
        /// Reports the model configuration to analytics the first time it is called.
        /// Calls to this method are compiled out unless MLA_UNITY_ANALYTICS_MODULE is defined.
        /// </summary>
        /// <param name="sensors">The sensors passed to the current decision request.</param>
        [Conditional("MLA_UNITY_ANALYTICS_MODULE")]
        void SendAnalytics(IList<ISensor> sensors)
        {
            if (m_AnalyticsSent)
            {
                return;
            }

            // Flag first: we only ever *try* once per policy, whether or not anything is sent.
            m_AnalyticsSent = true;

            // The rest of this class tolerates a null runner; without this guard the
            // property reads below would throw before RequestDecision's null-safe call.
            if (m_ModelRunner == null)
            {
                return;
            }

            Analytics.InferenceAnalytics.InferenceModelSet(
                m_ModelRunner.Model,
                m_BehaviorName,
                m_ModelRunner.InferenceDevice,
                sensors,
                m_ActionSpec,
                m_Actuators
            );
        }

        /// <inheritdoc />
        public ref readonly ActionBuffers DecideAction()
        {
            if (m_ModelRunner == null)
            {
                // No runner available: report an empty action.
                m_LastActionBuffer = ActionBuffers.Empty;
                return ref m_LastActionBuffer;
            }

            // Run the pending batch, then pick out the action for the agent that last requested one.
            m_ModelRunner.DecideBatch();
            m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId);
            return ref m_LastActionBuffer;
        }

        /// <summary>
        /// Does nothing; this policy does not dispose the shared ModelRunner.
        /// </summary>
        public void Dispose()
        {
        }
    }
}