using System.Collections.Generic;
using Barracuda;
using UnityEngine.Profiling;

namespace MLAgents.InferenceBrain
{
    /// <summary>
    /// Queues agents that need a decision and runs them through a Barracuda model in a
    /// single batched inference pass, then writes the model outputs back to those agents.
    /// </summary>
    public class ModelRunner : IBatchedDecisionMaker
    {
        // Agents queued by PutObservations since the last DecideBatch call.
        private readonly List<Agent> m_Agents = new List<Agent>();
        private readonly ITensorAllocator m_TensorAllocator;
        private readonly TensorGenerator m_TensorGenerator;
        private readonly TensorApplier m_TensorApplier;

        private readonly NNModel m_Model;
        private readonly InferenceDevice m_InferenceDevice;

        // Null when constructed without a model, and after Dispose.
        private IWorker m_Engine;
        private readonly bool m_Verbose = false;
        private readonly string[] m_OutputNames;
        private readonly IReadOnlyList<TensorProxy> m_InferenceInputs;
        private IReadOnlyList<TensorProxy> m_InferenceOutputs;

        // Visual observation tensors are initialized lazily from the first agent of the first batch.
        private bool m_visualObservationsInitialized = false;

        /// <summary>
        /// Initializes the Brain with the Model that it will use when selecting actions for
        /// the agents.
        /// </summary>
        /// <param name="model">The Barracuda model to load. When null, no inference worker
        /// is created.</param>
        /// <param name="brainParameters">The parameters of the Brain used to generate the
        /// placeholder tensors.</param>
        /// <param name="inferenceDevice">Inference execution device. CPU is the fastest
        /// option for most of ML Agents models.</param>
        /// <param name="seed">The seed that will be used to initialize the RandomNormal
        /// and Multinomial objects used when running inference.</param>
        /// <remarks>
        /// This constructor does not itself reject a null <paramref name="model"/>. In that case
        /// it passes a null Barracuda model to the parameter loader, tensor generator and tensor
        /// applier.
        /// NOTE(review): earlier docs said a null model throws. Confirm how those helpers handle
        /// null before relying on either behaviour.
        /// </remarks>
        public ModelRunner(
            NNModel model,
            BrainParameters brainParameters,
            InferenceDevice inferenceDevice = InferenceDevice.CPU,
            int seed = 0)
        {
            Model barracudaModel;
            m_Model = model;
            m_InferenceDevice = inferenceDevice;
            m_TensorAllocator = new TensorCachingAllocator();
            if (model != null)
            {
#if BARRACUDA_VERBOSE
                m_Verbose = true;
#endif

                // Align the D logger's enabled flag with this runner's verbosity.
                D.logEnabled = m_Verbose;

                barracudaModel = ModelLoader.Load(model.Value);
                var executionDevice = inferenceDevice == InferenceDevice.GPU
                    ? BarracudaWorkerFactory.Type.ComputePrecompiled
                    : BarracudaWorkerFactory.Type.CSharp;
                m_Engine = BarracudaWorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose);
            }
            else
            {
                barracudaModel = null;
                m_Engine = null;
            }

            m_InferenceInputs = BarracudaModelParamLoader.GetInputTensors(barracudaModel);
            m_OutputNames = BarracudaModelParamLoader.GetOutputNames(barracudaModel);
            m_TensorGenerator = new TensorGenerator(brainParameters, seed, m_TensorAllocator, barracudaModel);
            m_TensorApplier = new TensorApplier(brainParameters, seed, m_TensorAllocator, barracudaModel);
        }

        /// <summary>
        /// Builds the name-to-tensor map handed to the Barracuda worker.
        /// </summary>
        /// <param name="infInputs">Input tensor proxies whose data has already been generated.</param>
        /// <returns>A dictionary keyed by input name. If two inputs share a name, the later
        /// one wins.</returns>
        private static Dictionary<string, Tensor> PrepareBarracudaInputs(IEnumerable<TensorProxy> infInputs)
        {
            var inputs = new Dictionary<string, Tensor>();
            foreach (var inp in infInputs)
            {
                inputs[inp.name] = inp.data;
            }

            return inputs;
        }

        /// <summary>
        /// Disposes the Barracuda worker (if any) and resets the tensor allocator.
        /// A repeated call does not dispose the worker a second time.
        /// </summary>
        public void Dispose()
        {
            m_Engine?.Dispose();
            // Drop the reference so a second Dispose does not dispose the worker again.
            m_Engine = null;
            m_TensorAllocator?.Reset(false);
        }

        /// <summary>
        /// Reads the named output tensors from the worker after it has executed.
        /// </summary>
        /// <param name="names">Names of the model outputs to fetch.</param>
        /// <returns>One tensor proxy per name, in the same order as <paramref name="names"/>.</returns>
        private List<TensorProxy> FetchBarracudaOutputs(string[] names)
        {
            var outputs = new List<TensorProxy>(names.Length);
            foreach (var n in names)
            {
                var output = m_Engine.Peek(n);
                outputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
            }

            return outputs;
        }

        /// <summary>
        /// Queues an agent for the next <see cref="DecideBatch"/> call.
        /// </summary>
        /// <param name="key">Not used by this implementation.</param>
        /// <param name="agent">The agent that needs a decision.</param>
        public void PutObservations(string key, Agent agent)
        {
            m_Agents.Add(agent);
        }

        /// <summary>
        /// Runs one inference pass over every agent queued since the last call and applies
        /// the model outputs back to those agents. Does nothing if no agents are queued.
        /// </summary>
        /// <remarks>
        /// The queue is emptied even when a step throws. Otherwise the stale agents would be
        /// re-queued by the next PutObservations calls and appear twice in the next batch.
        /// </remarks>
        public void DecideBatch()
        {
            var currentBatchSize = m_Agents.Count;
            if (currentBatchSize == 0)
            {
                return;
            }

            try
            {
                if (!m_visualObservationsInitialized)
                {
                    // Just grab the first agent in the collection (any will suffice, really).
                    // We check for an empty Collection above, so this will always return successfully.
                    var firstAgent = m_Agents[0];
                    m_TensorGenerator.InitializeVisualObservations(firstAgent, m_TensorAllocator);
                    m_visualObservationsInitialized = true;
                }

                Profiler.BeginSample("LearningBrain.DecideAction");

                // Read the model name once; it is reused for every sample label below.
                var modelName = m_Model.name;

                Profiler.BeginSample($"MLAgents.{modelName}.GenerateTensors");
                // Prepare the input tensors to be fed into the engine
                m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Agents);
                Profiler.EndSample();

                Profiler.BeginSample($"MLAgents.{modelName}.PrepareBarracudaInputs");
                var inputs = PrepareBarracudaInputs(m_InferenceInputs);
                Profiler.EndSample();

                // Execute the Model
                Profiler.BeginSample($"MLAgents.{modelName}.ExecuteGraph");
                m_Engine.Execute(inputs);
                Profiler.EndSample();

                Profiler.BeginSample($"MLAgents.{modelName}.FetchBarracudaOutputs");
                m_InferenceOutputs = FetchBarracudaOutputs(m_OutputNames);
                Profiler.EndSample();

                Profiler.BeginSample($"MLAgents.{modelName}.ApplyTensors");
                // Update the outputs
                m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_Agents);
                Profiler.EndSample();

                Profiler.EndSample();
            }
            finally
            {
                m_Agents.Clear();
            }
        }

        /// <summary>
        /// Reports whether this runner was created for the given model and inference device.
        /// </summary>
        /// <param name="other">Model to compare against the one passed to the constructor.</param>
        /// <param name="otherInferenceDevice">Device to compare against the constructor's device.</param>
        /// <returns>True when both the model and the device match.</returns>
        public bool HasModel(NNModel other, InferenceDevice otherInferenceDevice)
        {
            return m_Model == other && m_InferenceDevice == otherInferenceDevice;
        }
    }
}