using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Policies
{
    /// <summary>
    /// Where to perform inference.
    /// </summary>
    public enum InferenceDevice
    {
        /// <summary>
        /// CPU inference
        /// </summary>
        CPU = 0,

        /// <summary>
        /// GPU inference
        /// </summary>
        GPU = 1
    }

    /// <summary>
    /// The Barracuda Policy uses a Barracuda Model to make decisions at
    /// every step. It uses a ModelRunner that is shared across all
    /// Barracuda Policies that use the same model and inference devices.
    /// </summary>
    internal class BarracudaPolicy : IPolicy
    {
        /// <summary>
        /// Runner obtained from <c>Academy.Instance.GetOrCreateModelRunner</c>.
        /// Every use in this class tolerates it being null.
        /// </summary>
        protected ModelRunner m_ModelRunner;

        /// <summary>
        /// Id (taken from <c>AgentInfo.episodeId</c>) of the agent passed to the most
        /// recent <see cref="RequestDecision"/> call; used by <see cref="DecideAction"/>
        /// to look up that agent's action.
        /// </summary>
        int m_AgentId;

        private readonly string m_BehaviorName;
        private readonly BrainParameters m_BrainParameters;

        /// <summary>
        /// Whether or not we've tried to send analytics for this model. We only ever try to send once per policy,
        /// and do additional deduplication in the analytics code.
        /// </summary>
        private bool m_AnalyticsSent;

        /// <summary>
        /// Creates a policy backed by a model runner obtained from
        /// <c>Academy.Instance.GetOrCreateModelRunner</c> for the given model,
        /// brain parameters and inference device.
        /// </summary>
        /// <param name="brainParameters">Brain parameters of the agents using this policy.</param>
        /// <param name="model">The Barracuda model to run.</param>
        /// <param name="inferenceDevice">Device on which inference is performed.</param>
        /// <param name="behaviorName">Behavior name, reported to inference analytics.</param>
        public BarracudaPolicy(
            BrainParameters brainParameters,
            NNModel model,
            InferenceDevice inferenceDevice,
            string behaviorName
        )
        {
            m_ModelRunner = Academy.Instance.GetOrCreateModelRunner(model, brainParameters, inferenceDevice);
            m_BehaviorName = behaviorName;
            m_BrainParameters = brainParameters;
        }

        /// <inheritdoc />
        public void RequestDecision(AgentInfo info, List<ISensor> sensors)
        {
            // Analytics reads Model/InferenceDevice off the runner, so it can only be
            // sent when a runner exists. Every other access in this class uses ?.,
            // so a missing runner must not surface as a NullReferenceException here.
            if (!m_AnalyticsSent && m_ModelRunner != null)
            {
                m_AnalyticsSent = true;
                Analytics.InferenceAnalytics.InferenceModelSet(
                    m_ModelRunner.Model,
                    m_BehaviorName,
                    m_ModelRunner.InferenceDevice,
                    sensors,
                    m_BrainParameters
                );
            }

            // DecideAction later calls GetAction with this same id; presumably the
            // ModelRunner indexes queued agents by episodeId — TODO confirm.
            m_AgentId = info.episodeId;
            m_ModelRunner?.PutObservations(info, sensors);
        }

        /// <inheritdoc />
        /// <remarks>
        /// Calls <c>DecideBatch</c> on the model runner, then returns <c>GetAction</c>
        /// for the agent recorded by the last <see cref="RequestDecision"/> call.
        /// Returns null when there is no model runner.
        /// </remarks>
        public float[] DecideAction()
        {
            m_ModelRunner?.DecideBatch();
            return m_ModelRunner?.GetAction(m_AgentId);
        }

        /// <summary>
        /// No-op: the model runner is shared across policies, so it is not released here.
        /// </summary>
        public void Dispose()
        {
        }
    }
}