using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Policies
{
    /// <summary>
    /// Where to perform inference.
    /// </summary>
    public enum InferenceDevice
    {
        /// <summary>
        /// CPU inference
        /// </summary>
        CPU = 0,

        /// <summary>
        /// GPU inference
        /// </summary>
        GPU = 1
    }

    /// <summary>
    /// The Barracuda Policy uses a Barracuda Model to make decisions at
    /// every step. It uses a ModelRunner that is shared across all
    /// Barracuda Policies that use the same model and inference devices.
    /// </summary>
    internal class BarracudaPolicy : IPolicy
    {
        /// <summary>
        /// Runner obtained from <c>Academy.Instance.GetOrCreateModelRunner</c> and shared with
        /// other policies that use the same model and inference device. Every use below goes
        /// through <c>?.</c>, so a null runner turns this policy into a no-op whose
        /// <see cref="DecideAction"/> returns null.
        /// </summary>
        protected ModelRunner m_ModelRunner;

        /// <summary>
        /// Key recorded by the most recent <see cref="RequestDecision"/> call. It is used by
        /// <see cref="DecideAction"/> to fetch that agent's action from the shared runner.
        /// </summary>
        int m_AgentId;

        /// <summary>
        /// Creates a policy backed by the shared <see cref="ModelRunner"/> for the given
        /// model / brain parameters / inference device combination.
        /// </summary>
        /// <param name="brainParameters">Brain parameters the model runner is created for.</param>
        /// <param name="model">The Barracuda model asset used for inference.</param>
        /// <param name="inferenceDevice">Where inference should run (CPU or GPU).</param>
        public BarracudaPolicy(
            BrainParameters brainParameters,
            NNModel model,
            InferenceDevice inferenceDevice)
        {
            m_ModelRunner = Academy.Instance.GetOrCreateModelRunner(model, brainParameters, inferenceDevice);
        }

        /// <summary>
        /// Records which agent is requesting a decision and queues its observations on the
        /// shared model runner.
        /// </summary>
        /// <param name="info">Agent information; its <c>episodeId</c> is remembered as the lookup key.</param>
        /// <param name="sensors">The agent's sensors, forwarded to the model runner.</param>
        public void RequestDecision(AgentInfo info, List<ISensor> sensors)
        {
            // NOTE(review): the lookup key is info.episodeId. Presumably this is unique per
            // agent among those sharing the runner; confirm against ModelRunner.
            m_AgentId = info.episodeId;
            m_ModelRunner?.PutObservations(info, sensors);
        }

        /// <summary>
        /// Asks the shared model runner to process its queued batch. It then returns the action
        /// computed for the agent recorded by the last <see cref="RequestDecision"/> call.
        /// </summary>
        /// <returns>
        /// The action for the recorded agent, or null when no model runner is available.
        /// </returns>
        public float[] DecideAction()
        {
            m_ModelRunner?.DecideBatch();
            return m_ModelRunner?.GetAction(m_AgentId);
        }

        /// <summary>
        /// Releases policy resources. This is intentionally empty.
        /// </summary>
        public void Dispose()
        {
            // The ModelRunner is shared across policies (obtained via GetOrCreateModelRunner),
            // so this policy must not dispose it. Presumably its lifetime is owned by the
            // Academy; confirm before adding cleanup here.
        }
    }
}