using System;
using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Policies
{
    /// <summary>
    /// Where to perform inference.
    /// </summary>
    public enum InferenceDevice
    {
        /// <summary>
        /// CPU inference
        /// </summary>
        CPU = 0,

        /// <summary>
        /// GPU inference
        /// </summary>
        GPU = 1
    }

    /// <summary>
    /// The Barracuda Policy uses a Barracuda Model to make decisions at
    /// every step. It uses a ModelRunner that is shared across all
    /// Barracuda Policies that use the same model and inference devices.
    /// </summary>
    internal class BarracudaPolicy : IPolicy
    {
        /// <summary>
        /// Runner obtained from the Academy in the constructor; every call on it
        /// in this class is null-conditional, so a null runner is tolerated.
        /// </summary>
        protected ModelRunner m_ModelRunner;

        /// <summary>
        /// Backing storage for the value returned by reference from <see cref="DecideAction"/>.
        /// </summary>
        ActionBuffers m_LastActionBuffer;

        /// <summary>
        /// Episode id of the agent that most recently called <see cref="RequestDecision"/>.
        /// </summary>
        int m_AgentId;

        /// <summary>
        /// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
        /// </summary>
        // NOTE(review): not read or written anywhere in this class; the element type
        // (int[]) was reconstructed after the type argument was lost - confirm.
        List<int[]> m_SensorShapes;

        /// <summary>
        /// Action space kind chosen once in the constructor; selects how
        /// <see cref="DecideAction"/> wraps the runner's output.
        /// </summary>
        readonly SpaceType m_SpaceType;

        /// <summary>
        /// Creates a policy backed by the model runner the Academy provides for the
        /// given model, action spec and inference device.
        /// </summary>
        /// <param name="actionSpec">Action specification; it is validated with
        /// <c>CheckNotHybrid</c> and treated as continuous when
        /// <c>NumContinuousActions</c> is greater than zero, discrete otherwise.</param>
        /// <param name="model">The Barracuda model used for inference.</param>
        /// <param name="inferenceDevice">The device on which inference is performed.</param>
        public BarracudaPolicy(
            ActionSpec actionSpec,
            NNModel model,
            InferenceDevice inferenceDevice)
        {
            m_ModelRunner = Academy.Instance.GetOrCreateModelRunner(model, actionSpec, inferenceDevice);
            actionSpec.CheckNotHybrid();
            m_SpaceType = actionSpec.NumContinuousActions > 0 ? SpaceType.Continuous : SpaceType.Discrete;
        }

        /// <summary>
        /// Remembers the requesting agent's episode id and hands its observations
        /// to the model runner, if there is one.
        /// </summary>
        /// <param name="info">Agent information; its <c>episodeId</c> is stored so the
        /// matching action can be looked up in <see cref="DecideAction"/>.</param>
        /// <param name="sensors">Sensors forwarded to the model runner together with
        /// <paramref name="info"/>.</param>
        // NOTE(review): element type ISensor was reconstructed after the type
        // argument was lost - confirm against IPolicy.RequestDecision.
        public void RequestDecision(AgentInfo info, List<ISensor> sensors)
        {
            m_AgentId = info.episodeId;
            m_ModelRunner?.PutObservations(info, sensors);
        }

        /// <summary>
        /// Asks the model runner to decide for the current batch, fetches the action
        /// for the agent recorded by the last <see cref="RequestDecision"/> call and
        /// wraps it in an <see cref="ActionBuffers"/>.
        /// </summary>
        /// <returns>
        /// A read-only reference to the buffer stored in this policy; the buffer is
        /// overwritten by the next call.
        /// </returns>
        public ref readonly ActionBuffers DecideAction()
        {
            m_ModelRunner?.DecideBatch();

            // When there is no runner this is null and is passed through unchanged.
            var actions = m_ModelRunner?.GetAction(m_AgentId);

            if (m_SpaceType == SpaceType.Continuous)
            {
                // Continuous space: no discrete branch, so supply an empty array.
                // NOTE(review): <int> was reconstructed after the type argument was
                // lost - confirm against the ActionBuffers constructor.
                m_LastActionBuffer = new ActionBuffers(actions, Array.Empty<int>());
                return ref m_LastActionBuffer;
            }

            m_LastActionBuffer = ActionBuffers.FromDiscreteActions(actions);
            return ref m_LastActionBuffer;
        }

        /// <summary>
        /// Intentionally empty: this policy releases nothing itself.
        /// </summary>
        // NOTE(review): presumably the shared ModelRunner is owned and disposed by
        // the Academy - confirm.
        public void Dispose()
        {
        }
    }
}