// ModelRunner for C# training.

using System;
using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using UnityEngine;
using Unity.MLAgents.Inference.Utils;

namespace Unity.MLAgents
{
    /// <summary>
    /// Runs a Barracuda model both to pick actions for agents (<see cref="DecideBatch"/>)
    /// and to execute a training step on a batch of transitions (<see cref="UpdateModel"/>).
    /// The training state tensor produced by each training step is kept and fed back
    /// into the next model execution.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the generic type arguments on the fields and parameters below were
    /// reconstructed from how each member is used in this class. The element types
    /// <c>ISensor</c> and <c>Transition</c>, and the value type of <c>m_Memories</c>,
    /// are assumptions - confirm against TensorGenerator, TrainingTensorGenerator and
    /// ReplayBuffer.
    /// </remarks>
    internal class TrainingModelRunner : IDisposable
    {
        // Output requested after a training step; hoisted so UpdateModel does not
        // allocate a new array on every call.
        static readonly string[] s_TrainingStateOutputNames = { TensorNames.TrainingStateOut };

        // Decision requests accumulated since the last DecideBatch call.
        readonly List<AgentInfoSensorsPair> m_Infos = new List<AgentInfoSensorsPair>();

        // Latest action per episode id; entries are dropped once the agent reports done.
        readonly Dictionary<int, ActionBuffers> m_LastActionsReceived = new Dictionary<int, ActionBuffers>();

        // Episode ids in the order their decisions were requested.
        readonly List<int> m_OrderedAgentsRequestingDecisions = new List<int>();

        // Training state passed into every model execution and replaced by UpdateModel.
        TensorProxy m_TrainingState;

        readonly ITensorAllocator m_TensorAllocator;
        readonly TensorGenerator m_TensorGenerator;
        readonly TrainingTensorGenerator m_TrainingTensorGenerator;
        readonly TrainingForwardTensorApplier m_TensorApplier;

        readonly Model m_Model;
        readonly IWorker m_Engine;
        bool m_Verbose = false;
        readonly string[] m_OutputNames;
        readonly IReadOnlyList<TensorProxy> m_TrainingInputs;
        readonly List<TensorProxy> m_TrainingOutputs;
        readonly Dictionary<string, Tensor> m_InputsByName;
        readonly Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();

        bool m_ObservationsInitialized;

        // Not read or written anywhere in this class at present.
        bool m_TrainingObservationsInitialized;

        readonly ReplayBuffer m_Buffer;

        /// <summary>
        /// Loads the Barracuda model, creates the worker that executes it, and builds the
        /// tensor generators and applier used for decisions and training steps.
        /// </summary>
        /// <param name="actionSpec">Description of the actions for the Agent; passed to the tensor applier.</param>
        /// <param name="model">The Barracuda model asset to load.</param>
        /// <param name="buffer">Replay buffer; <see cref="DecideBatch"/> samples a dummy batch from it.</param>
        /// <param name="config">Trainer configuration; its learningRate and gamma are passed to the training tensor generator.</param>
        /// <param name="seed">Seed passed to the tensor generators and the tensor applier.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="model"/> is null.</exception>
        public TrainingModelRunner(
            ActionSpec actionSpec,
            NNModel model,
            ReplayBuffer buffer,
            TrainerConfig config,
            int seed = 0)
        {
            // Fail fast with a descriptive exception before anything is allocated.
            if (model == null)
            {
                throw new ArgumentNullException(nameof(model));
            }

            m_TensorAllocator = new TensorCachingAllocator();

            var barracudaModel = ModelLoader.Load(model);
            m_Model = barracudaModel;

            // The execution device is fixed to the C# Burst backend.
            var executionDevice = WorkerFactory.Type.CSharpBurst;
            m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose);

            m_TrainingInputs = barracudaModel.GetTrainingInputTensors();
            m_OutputNames = barracudaModel.GetOutputNames();
            InitializeTrainingState(barracudaModel);

            m_TensorGenerator = new TensorGenerator(
                seed, m_TensorAllocator, m_Memories, barracudaModel);
            m_TrainingTensorGenerator = new TrainingTensorGenerator(
                seed, m_TensorAllocator, config.learningRate, config.gamma, barracudaModel);
            m_TensorApplier = new TrainingForwardTensorApplier(
                actionSpec, seed, m_TensorAllocator, barracudaModel);

            m_InputsByName = new Dictionary<string, Tensor>();
            m_TrainingOutputs = new List<TensorProxy>();
            m_Buffer = buffer;
        }

        /// <summary>
        /// Seeds <c>m_TrainingState</c> from the model tensor named
        /// <c>TensorNames.InitialTrainingState</c>.
        /// </summary>
        void InitializeTrainingState(Model barracudaModel)
        {
            m_TrainingState = new TensorProxy
            {
                data = barracudaModel.GetTensorByName(TensorNames.InitialTrainingState)
            };
        }

        /// <summary>
        /// Rebuilds the name-to-tensor map that is handed to the worker.
        /// </summary>
        void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)
        {
            m_InputsByName.Clear();
            for (var i = 0; i < infInputs.Count; i++)
            {
                var inp = infInputs[i];
                m_InputsByName[inp.name] = inp.data;
            }
        }

        /// <summary>
        /// Disposes the Barracuda worker and resets the tensor allocator.
        /// </summary>
        public void Dispose()
        {
            m_Engine?.Dispose();
            m_TensorAllocator?.Reset(false);
        }

        /// <summary>
        /// Replaces the contents of <c>m_TrainingOutputs</c> with the named worker outputs,
        /// in the order given.
        /// </summary>
        void FetchBarracudaOutputs(string[] names)
        {
            m_TrainingOutputs.Clear();
            foreach (var n in names)
            {
                var output = m_Engine.PeekOutput(n);
                m_TrainingOutputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
            }
        }

        /// <summary>
        /// Queues an agent's observations for the next <see cref="DecideBatch"/> call.
        /// </summary>
        /// <param name="info">The agent's current info; its episodeId keys the stored action.</param>
        /// <param name="sensors">The agent's sensors.</param>
        public void PutObservations(AgentInfo info, List<ISensor> sensors)
        {
            m_Infos.Add(new AgentInfoSensorsPair
            {
                agentInfo = info,
                sensors = sensors
            });

            // We add the episodeId to this list to maintain the order in which the
            // decisions were requested.
            m_OrderedAgentsRequestingDecisions.Add(info.episodeId);

            if (info.done)
            {
                // If the agent is done, we remove the key from the last action dictionary
                // since no action should be taken.
                m_LastActionsReceived.Remove(info.episodeId);
            }
            else if (!m_LastActionsReceived.ContainsKey(info.episodeId))
            {
                m_LastActionsReceived[info.episodeId] = ActionBuffers.Empty;
            }
        }

        /// <summary>
        /// Fills <paramref name="tensors"/> with the observations of a single agent
        /// (batch size 1).
        /// </summary>
        /// <param name="tensors">Tensors to generate into.</param>
        /// <param name="info">The agent's info.</param>
        /// <param name="sensors">The agent's sensors.</param>
        public void GetObservationTensors(IReadOnlyList<TensorProxy> tensors, AgentInfo info, List<ISensor> sensors)
        {
            if (!m_ObservationsInitialized)
            {
                // Observation generators are initialized once, from the sensors of the
                // first agent seen by either this method or DecideBatch.
                m_TensorGenerator.InitializeObservations(sensors, m_TensorAllocator);
                m_ObservationsInitialized = true;
            }

            var infoSensorPair = new AgentInfoSensorsPair
            {
                agentInfo = info,
                sensors = sensors
            };
            m_TensorGenerator.GenerateTensors(tensors, 1, new List<AgentInfoSensorsPair> { infoSensorPair });
        }

        /// <summary>
        /// Returns the input tensors declared by the loaded model.
        /// </summary>
        public IReadOnlyList<TensorProxy> GetInputTensors()
        {
            return m_Model.GetInputTensors();
        }

        /// <summary>
        /// Executes the model on every request queued by <see cref="PutObservations"/>,
        /// applies the outputs to the per-agent action dictionary, and clears the queue.
        /// Does nothing when no requests are queued.
        /// </summary>
        public void DecideBatch()
        {
            var currentBatchSize = m_Infos.Count;
            if (currentBatchSize == 0)
            {
                return;
            }

            if (!m_ObservationsInitialized)
            {
                // Just grab the first agent in the collection (any will suffice, really).
                // The empty-batch check above guarantees there is at least one.
                var firstInfo = m_Infos[0];
                m_TensorGenerator.InitializeObservations(firstInfo.sensors, m_TensorAllocator);
                m_ObservationsInitialized = true;
            }

            // Prepare the input tensors to be fed into the engine. The training inputs are
            // filled from a dummy batch sampled from the replay buffer.
            m_TensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Infos);
            m_TrainingTensorGenerator.GenerateTensors(
                m_TrainingInputs,
                currentBatchSize,
                m_Buffer.SampleDummyBatch(currentBatchSize),
                m_TrainingState);

            PrepareBarracudaInputs(m_TrainingInputs);

            // Execute the Model
            m_Engine.Execute(m_InputsByName);

            FetchBarracudaOutputs(m_OutputNames);

            // Update the outputs
            m_TensorApplier.ApplyTensors(m_TrainingOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);

            m_Infos.Clear();
            m_OrderedAgentsRequestingDecisions.Clear();
        }

        /// <summary>
        /// Executes the model on a batch of transitions and stores the resulting
        /// <c>TensorNames.TrainingStateOut</c> output as the new training state.
        /// Does nothing when <paramref name="transitions"/> is empty.
        /// </summary>
        /// <param name="transitions">The batch to train on.</param>
        public void UpdateModel(List<Transition> transitions)
        {
            var currentBatchSize = transitions.Count;
            if (currentBatchSize == 0)
            {
                return;
            }

            m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions, m_TrainingState, true);

            PrepareBarracudaInputs(m_TrainingInputs);

            // Execute the Model
            m_Engine.Execute(m_InputsByName);

            // Update the model
            FetchBarracudaOutputs(s_TrainingStateOutputNames);

            // NOTE(review): this tensor comes from IWorker.PeekOutput and is then fed back
            // as an input to the next Execute call. Confirm against the Barracuda docs that
            // a peeked tensor stays valid that long; otherwise it needs a copy or
            // ownership transfer.
            m_TrainingState = m_TrainingOutputs[0];
        }

        /// <summary>
        /// Returns the last action stored for the given agent.
        /// </summary>
        /// <param name="agentId">Key into the last-action dictionary (the episode id used by <see cref="PutObservations"/>).</param>
        /// <returns>The stored action, or <c>ActionBuffers.Empty</c> when there is no entry.</returns>
        public ActionBuffers GetAction(int agentId)
        {
            return m_LastActionsReceived.TryGetValue(agentId, out var action)
                ? action
                : ActionBuffers.Empty;
        }
    }
}