#define ENABLE_BARRACUDA
using System;
using System.Collections.Generic;
using UnityEngine;
using System.Linq;
using Barracuda;
using MLAgents.InferenceBrain;
using UnityEngine.Profiling;
using Tensor = MLAgents.InferenceBrain.Tensor;
namespace MLAgents
{
/// <summary>
/// Device on which a Learning Brain executes its model during inference.
/// The numeric values are spelled out explicitly so they stay fixed if members are reordered.
/// </summary>
public enum InferenceDevice
{
    /// <summary>Execute inference on the CPU.</summary>
    CPU = 0,

    /// <summary>Execute inference on the GPU.</summary>
    GPU = 1,
}
/// <summary>
/// The Learning Brain works differently if you are training it or not.
/// When training your Agents, drag the Learning Brain to the Academy's BroadcastHub and check
/// the checkbox Control. When using a pretrained model, just drag the Model file into the
/// Model property of the Learning Brain.
/// The property model corresponds to the Model currently attached to the Brain. Before
/// being used, a call to ReloadModel is required.
/// When the Learning Brain is not training, it uses a TensorFlow model to make decisions.
/// The Proximal Policy Optimization (PPO) and Behavioral Cloning algorithms included with
/// the ML-Agents SDK produce trained TensorFlow models that you can use with the
/// Learning Brain.
/// </summary>
[CreateAssetMenu(fileName = "NewLearningBrain", menuName = "ML-Agents/Learning Brain")]
public class LearningBrain : Brain
{
    private TensorGenerator _tensorGenerator;
    private TensorApplier _tensorApplier;
#if ENABLE_TENSORFLOW
    public TextAsset model;
    private ModelParamLoader _modelParamLoader;
    private TFSharpInferenceEngine _engine;
#elif ENABLE_BARRACUDA
    public NNModel model;
    private Model _barracudaModel;
    private IWorker _engine;
    private bool _verbose = false;
    private BarracudaModelParamLoader _modelParamLoader;
    private string[] _outputNames;
#endif
    [Tooltip("Inference execution device. CPU is the fastest option for most of ML Agents models. " +
             "(This field is not applicable for training).")]
    public InferenceDevice inferenceDevice = InferenceDevice.CPU;

    private IReadOnlyList<Tensor> _inferenceInputs;
    private IReadOnlyList<Tensor> _inferenceOutputs;

    [NonSerialized]
    private bool _isControlled;

    /// <summary>
    /// When called, the brain will be controlled externally. It will not use the
    /// model to decide on actions.
    /// </summary>
    public void SetToControlledExternally()
    {
        _isControlled = true;
    }

    /// <inheritdoc />
    protected override void Initialize()
    {
        ReloadModel();
    }

    /// <summary>
    /// Initializes the Brain with the Model that it will use when selecting actions for
    /// the agents.
    /// </summary>
    /// <param name="seed">The seed that will be used to initialize the RandomNormal
    /// and Multinomial objects used when running inference.</param>
    /// <remarks>
    /// When no model is assigned the inference engine is left unset; DecideAction then logs
    /// an error instead of running inference.
    /// </remarks>
    public void ReloadModel(int seed = 0)
    {
#if ENABLE_TENSORFLOW
        if (model != null)
        {
            _engine = new TFSharpInferenceEngine();
            _engine.PrepareModel(model.bytes);
        }
        else
        {
            _engine = null;
        }
        _modelParamLoader = ModelParamLoader.GetLoaderAndCheck(_engine, brainParameters);
        _inferenceInputs = _modelParamLoader.GetInputTensors();
        _inferenceOutputs = _modelParamLoader.GetOutputTensors();
        _tensorGenerator = new TensorGenerator(brainParameters, seed);
        _tensorApplier = new TensorApplier(brainParameters, seed);
#elif ENABLE_BARRACUDA
        // Release the previous worker in every case, including when the model was removed,
        // so a reload never leaks it. Clearing the field first also guarantees that a failure
        // while loading below cannot leave a disposed worker behind.
        _engine?.Dispose();
        _engine = null;

        if (model != null)
        {
#if BARRACUDA_VERBOSE
            _verbose = true;
#endif
            D.logEnabled = _verbose;
            _barracudaModel = ModelLoader.Load(model.Value);
            var executionDevice = inferenceDevice == InferenceDevice.GPU
                ? BarracudaWorkerFactory.Type.ComputeFast
                : BarracudaWorkerFactory.Type.CSharpFast;
            _engine = BarracudaWorkerFactory.CreateWorker(executionDevice, _barracudaModel, _verbose);
        }
        else
        {
            _barracudaModel = null;
        }

        _modelParamLoader = BarracudaModelParamLoader.GetLoaderAndCheck(_engine, _barracudaModel, brainParameters);
        _inferenceInputs = _modelParamLoader.GetInputTensors();
        _outputNames = _modelParamLoader.GetOutputNames();
        _tensorGenerator = new TensorGenerator(brainParameters, seed);
        _tensorApplier = new TensorApplier(brainParameters, seed);
#endif
    }

    /// <summary>
    /// Return a list of failed checks corresponding to the failed compatibility checks
    /// between the Model and the BrainParameters. Note : This does not reload the model.
    /// If changes have been made to the BrainParameters or the Model, the model must be
    /// reloaded using ReloadModel before trying to get the compatibility checks.
    /// </summary>
    /// <returns>The list of the failed compatibility checks between the Model and the
    /// Brain Parameters. Empty when no parameter loader has been created yet.</returns>
    public IEnumerable<string> GetModelFailedChecks()
    {
#if ENABLE_TENSORFLOW || ENABLE_BARRACUDA
        return (_modelParamLoader != null) ? _modelParamLoader.GetChecks() : new List<string>();
#else
        return new List<string>()
        {
            "You need to install the TensorflowSharp plugin and add the ENABLE_TENSORFLOW " +
            "flag in your Player Settings in order to use inference. "
        };
#endif
    }

    /// <inheritdoc />
    protected override void DecideAction()
    {
        if (_isControlled)
        {
            // Actions are decided externally: drop the pending requests without inference.
            agentInfos.Clear();
            return;
        }

        var currentBatchSize = agentInfos.Count;
        if (currentBatchSize == 0)
        {
            return;
        }

#if ENABLE_TENSORFLOW
        if (_engine == null)
        {
            Debug.LogError($"No model was present for the Brain {name}.");
            // Drop the pending requests like every other exit path does, so they do not
            // pile up while the model is missing.
            agentInfos.Clear();
            return;
        }

        // Prepare the input tensors to be fed into the engine
        _tensorGenerator.GenerateTensors(_inferenceInputs, currentBatchSize, agentInfos);
        // Prepare the output tensors to be fed into the engine
        _tensorGenerator.GenerateTensors(_inferenceOutputs, currentBatchSize, agentInfos);

        // Execute the Model
        Profiler.BeginSample($"MLAgents.{name}.ExecuteGraph");
        _engine.ExecuteGraph(_inferenceInputs, _inferenceOutputs);
        Profiler.EndSample();

        // Update the outputs
        _tensorApplier.ApplyTensors(_inferenceOutputs, agentInfos);
#elif ENABLE_BARRACUDA
        if (_engine == null)
        {
            Debug.LogError($"No model was present for the Brain {name}.");
            // Drop the pending requests like every other exit path does, so they do not
            // pile up while the model is missing.
            agentInfos.Clear();
            return;
        }

        // Prepare the input tensors to be fed into the engine
        _tensorGenerator.GenerateTensors(_inferenceInputs, currentBatchSize, agentInfos);
        var inputs = PrepareBarracudaInputs(_inferenceInputs);
        try
        {
            // Execute the Model
            Profiler.BeginSample($"MLAgents.{name}.ExecuteGraph");
            _engine.Execute(inputs);
            Profiler.EndSample();

            _inferenceOutputs = FetchBarracudaOutputs(_outputNames);
        }
        finally
        {
            // The converted input tensors are disposable; release them even when
            // execution or fetching fails.
            CleanupBarracudaState(inputs);
        }

        // Update the outputs
        _tensorApplier.ApplyTensors(_inferenceOutputs, agentInfos);
#else
        // The batch is known to be non-empty here (see the early return above).
        Debug.LogError(
            $"The brain {name} was set to inference mode but the Tensorflow library is not " +
            "present in the Unity project.");
#endif
        agentInfos.Clear();
    }

#if ENABLE_BARRACUDA && !ENABLE_TENSORFLOW
    /// <summary>
    /// Converts the given inference tensors into Barracuda tensors keyed by tensor name.
    /// </summary>
    /// <param name="infInputs">The tensors to convert.</param>
    /// <returns>A dictionary mapping each tensor name to its Barracuda counterpart. The
    /// caller is responsible for disposing the values (see CleanupBarracudaState).</returns>
    protected Dictionary<string, Barracuda.Tensor> PrepareBarracudaInputs(IEnumerable<Tensor> infInputs)
    {
        var inputs = new Dictionary<string, Barracuda.Tensor>();
        // Iterate the argument, not the _inferenceInputs field, so the method honours
        // whatever collection the caller passes in.
        foreach (var inp in infInputs)
        {
            inputs[inp.Name] = BarracudaUtils.ToBarracuda(inp);
        }

        return inputs;
    }

    /// <summary>
    /// Fetches the named outputs from the engine and converts them to inference tensors.
    /// </summary>
    /// <param name="names">Names of the outputs to fetch, in the order they are returned.</param>
    /// <returns>One converted tensor per requested name.</returns>
    protected List<Tensor> FetchBarracudaOutputs(string[] names)
    {
        var outputs = new List<Tensor>(names.Length);
        foreach (var outputName in names)
        {
            // Dispose the fetched tensor once converted, even if the conversion throws.
            using (var fetched = _engine.Fetch(outputName))
            {
                outputs.Add(BarracudaUtils.FromBarracuda(fetched, outputName));
            }
        }

        return outputs;
    }

    /// <summary>
    /// Disposes every Barracuda tensor held in <paramref name="inputs"/> and empties the dictionary.
    /// </summary>
    /// <param name="inputs">The tensors created by PrepareBarracudaInputs.</param>
    protected void CleanupBarracudaState(Dictionary<string, Barracuda.Tensor> inputs)
    {
        foreach (var tensor in inputs.Values)
        {
            tensor.Dispose();
        }

        inputs.Clear();
    }

    /// <summary>
    /// Releases the inference engine when the brain is disabled.
    /// </summary>
    public void OnDisable()
    {
        _engine?.Dispose();
        // Forget the disposed worker so it is neither used by DecideAction nor disposed a
        // second time by a later ReloadModel.
        _engine = null;
    }
#endif
}
}