using System;
using System.Collections.Generic;
using System.Linq;
using Barracuda;

namespace MLAgents.InferenceBrain
{
    /// <summary>
    /// Prepares the Tensors for the Learning Brain and exposes a list of failed checks if Model
    /// and BrainParameters are incompatible.
    /// </summary>
    public class BarracudaModelParamLoader
    {
        private enum ModelActionType
        {
            Unknown,
            Discrete,
            Continuous
        }

        private const long k_ApiVersion = 2;

        /// <summary>
        /// Generates the Tensor inputs that are expected to be present in the Model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <returns>
        /// The expected input tensors (the model's inputs plus one entry per memory input),
        /// sorted by name. Empty when <paramref name="model"/> is null.
        /// </returns>
        public static IReadOnlyList<TensorProxy> GetInputTensors(Model model)
        {
            var tensors = new List<TensorProxy>();

            if (model == null)
            {
                return tensors;
            }

            foreach (var input in model.inputs)
            {
                tensors.Add(new TensorProxy
                {
                    name = input.name,
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = null,
                    shape = input.shape.Select(i => (long)i).ToArray()
                });
            }

            foreach (var mem in model.memories)
            {
                tensors.Add(new TensorProxy
                {
                    name = mem.input,
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = null,
                    shape = TensorUtils.TensorShapeFromBarracuda(mem.shape)
                });
            }

            tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));

            return tensors;
        }

        /// <summary>
        /// Generates the names of the Tensor outputs that are expected to be present in the Model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <returns>
        /// The sorted output names: the action output, plus every memory output when the
        /// model's memory size is positive. Empty when <paramref name="model"/> is null.
        /// </returns>
        public static string[] GetOutputNames(Model model)
        {
            var names = new List<string>();

            if (model == null)
            {
                return names.ToArray();
            }

            names.Add(TensorNames.ActionOutput);

            var memory = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
            if (memory > 0)
            {
                foreach (var mem in model.memories)
                {
                    names.Add(mem.output);
                }
            }

            names.Sort();

            return names.ToArray();
        }

        /// <summary>
        /// Factory for the ModelParamLoader : Creates a ModelParamLoader and runs the checks
        /// on it.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <returns>The list the error messages of the checks that failed</returns>
        public static IEnumerable<string> CheckModel(Model model, BrainParameters brainParameters)
        {
            var failedModelChecks = new List<string>();
            if (model == null)
            {
                failedModelChecks.Add(
                    "There is no model for this Brain, cannot run inference. " +
                    "(But can still train)");
                return failedModelChecks;
            }

            // Validate the API version before reading any other metadata node: a model with an
            // unknown or incompatible version is rejected without inspecting the rest of it.
            var modelApiVersion = (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
            if (modelApiVersion == -1)
            {
                failedModelChecks.Add(
                    "Model was not trained using the right version of ML-Agents. " +
                    "Cannot use this model.");
                return failedModelChecks;
            }
            if (modelApiVersion != k_ApiVersion)
            {
                failedModelChecks.Add(
                    $"Version of the trainer the model was trained with ({modelApiVersion}) " +
                    $"is not compatible with the Brain's version ({k_ApiVersion}).");
                return failedModelChecks;
            }

            var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
            var isContinuousInt = (int)model.GetTensorByName(TensorNames.IsContinuousControl)[0];
            var isContinuous = GetActionType(isContinuousInt);
            var actionSize = (int)model.GetTensorByName(TensorNames.ActionOutputShape)[0];

            failedModelChecks.AddRange(
                CheckIntScalarPresenceHelper(new Dictionary<string, int>()
                {
                    {TensorNames.MemorySize, memorySize},
                    {TensorNames.IsContinuousControl, isContinuousInt},
                    {TensorNames.ActionOutputShape, actionSize}
                })
            );
            failedModelChecks.AddRange(
                CheckInputTensorPresence(model, brainParameters, memorySize, isContinuous)
            );
            failedModelChecks.AddRange(
                CheckOutputTensorPresence(model, memorySize)
            );
            failedModelChecks.AddRange(
                CheckInputTensorShape(model, brainParameters)
            );
            failedModelChecks.AddRange(
                CheckOutputTensorShape(model, brainParameters, isContinuous, actionSize)
            );
            return failedModelChecks;
        }

        /// <summary>
        /// Converts the integer value in the model corresponding to the type of control to a
        /// ModelActionType.
        /// </summary>
        /// <param name="isContinuousInt">
        /// The integer value in the model indicating the type of control
        /// </param>
        /// <returns>
        /// <see cref="ModelActionType.Discrete"/> for 0, <see cref="ModelActionType.Continuous"/>
        /// for 1 and <see cref="ModelActionType.Unknown"/> for any other value.
        /// </returns>
        private static ModelActionType GetActionType(int isContinuousInt)
        {
            switch (isContinuousInt)
            {
                case 0:
                    return ModelActionType.Discrete;
                case 1:
                    return ModelActionType.Continuous;
                default:
                    return ModelActionType.Unknown;
            }
        }

        /// <summary>
        /// Given a Dictionary of node names to int values, create checks if the values have the
        /// invalid value of -1.
        /// </summary>
        /// <param name="requiredScalarFields">Mapping from node names to int values</param>
        /// <returns>One "Missing node" message per entry whose value is -1.</returns>
        private static IEnumerable<string> CheckIntScalarPresenceHelper(
            Dictionary<string, int> requiredScalarFields)
        {
            var failedModelChecks = new List<string>();
            foreach (var field in requiredScalarFields)
            {
                if (field.Value == -1)
                {
                    failedModelChecks.Add($"Missing node in the model provided : {field.Key}");
                }
            }
            return failedModelChecks;
        }

        /// <summary>
        /// Generates failed checks for inputs that the BrainParameters (or the model's own
        /// metadata) require but that are not present among the model's input tensors.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="memory">
        /// The memory size that the model is expecting.
        /// </param>
        /// <param name="isContinuous">
        /// Whether the model is expecting continuous or discrete control.
        /// </param>
        /// <returns>
        /// A IEnumerable of string corresponding to the failed input presence checks.
        /// </returns>
        private static IEnumerable<string> CheckInputTensorPresence(
            Model model,
            BrainParameters brainParameters,
            int memory,
            ModelActionType isContinuous)
        {
            var failedModelChecks = new List<string>();
            var tensorsNames = GetInputTensors(model).Select(x => x.name).ToList();

            // If there is no Vector Observation Input but the Brain Parameters expect one.
            if ((brainParameters.vectorObservationSize != 0) &&
                (!tensorsNames.Contains(TensorNames.VectorObservationPlacholder)))
            {
                failedModelChecks.Add(
                    "The model does not contain a Vector Observation Placeholder Input. " +
                    "You must set the Vector Observation Space Size to 0.");
            }

            // If there are not enough Visual Observation Input compared to what the
            // Brain Parameters expect.
            for (var visObsIndex = 0;
                 visObsIndex < brainParameters.cameraResolutions.Length;
                 visObsIndex++)
            {
                if (!tensorsNames.Contains(
                    TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
                {
                    failedModelChecks.Add(
                        "The model does not contain a Visual Observation Placeholder Input " +
                        "for visual observation " + visObsIndex + ".");
                }
            }

            // If the model has a positive memory size, it needs both recurrent input nodes
            // (names ending in "_h" and "_c"). Ordinal comparison: these are identifiers.
            if (memory > 0)
            {
                if (!tensorsNames.Any(x => x.EndsWith("_h", StringComparison.Ordinal)) ||
                    !tensorsNames.Any(x => x.EndsWith("_c", StringComparison.Ordinal)))
                {
                    failedModelChecks.Add(
                        "The model does not contain a Recurrent Input Node but has memory_size.");
                }
            }

            // If the model uses discrete control but does not have an input for action masks
            if (isContinuous == ModelActionType.Discrete)
            {
                if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
                {
                    failedModelChecks.Add(
                        "The model does not contain an Action Mask but is using Discrete Control.");
                }
            }
            return failedModelChecks;
        }

        /// <summary>
        /// Generates failed checks for outputs that must be present in the model (the action
        /// output and, for recurrent models, the memory outputs) but are missing.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <param name="memory">The memory size that the model is expecting</param>
        /// <returns>
        /// A IEnumerable of string corresponding to the failed output presence checks.
        /// </returns>
        private static IEnumerable<string> CheckOutputTensorPresence(Model model, int memory)
        {
            var failedModelChecks = new List<string>();

            // If there is no Action Output.
            if (!model.outputs.Contains(TensorNames.ActionOutput))
            {
                failedModelChecks.Add("The model does not contain an Action Output Node.");
            }

            // If there is no Recurrent Output but the model is Recurrent.
            if (memory > 0)
            {
                var memOutputs = model.memories.Select(x => x.output).ToList();

                if (!memOutputs.Any(x => x.EndsWith("_h", StringComparison.Ordinal)) ||
                    !memOutputs.Any(x => x.EndsWith("_c", StringComparison.Ordinal)))
                {
                    failedModelChecks.Add(
                        "The model does not contain a Recurrent Output Node but has memory_size.");
                }
            }
            return failedModelChecks;
        }

        /// <summary>
        /// Generates failed checks that correspond to inputs shapes incompatibilities between
        /// the model and the BrainParameters.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <returns>
        /// The list the error messages of the checks that failed, including one message per
        /// model input that has no registered checker.
        /// </returns>
        private static IEnumerable<string> CheckInputTensorShape(
            Model model, BrainParameters brainParameters)
        {
            var failedModelChecks = new List<string>();
            var tensorTester =
                new Dictionary<string, Func<BrainParameters, TensorProxy, string>>()
            {
                {TensorNames.VectorObservationPlacholder, CheckVectorObsShape},
                {TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
                {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor) => null)},
                {TensorNames.ActionMaskPlaceholder, ((bp, tensor) => null)},
                {TensorNames.SequenceLengthPlaceholder, ((bp, tensor) => null)},
                {TensorNames.RecurrentInPlaceholder, ((bp, tensor) => null)},
            };

            // Memory inputs are known inputs whose shape is not validated here.
            foreach (var mem in model.memories)
            {
                tensorTester[mem.input] = ((bp, tensor) => null);
            }

            for (var obsIndex = 0; obsIndex < brainParameters.cameraResolutions.Length; obsIndex++)
            {
                // Copy the loop variable so each lambda captures its own index.
                var index = obsIndex;
                tensorTester[TensorNames.VisualObservationPlaceholderPrefix + obsIndex] =
                    (bp, tensor) => CheckVisualObsShape(bp, tensor, index);
            }

            // If the model expects an input but it is not in this list
            foreach (var tensor in GetInputTensors(model))
            {
                if (!tensorTester.TryGetValue(tensor.name, out var tester))
                {
                    failedModelChecks.Add(
                        "Model requires an unknown input named : " + tensor.name);
                    continue;
                }

                var error = tester.Invoke(brainParameters, tensor);
                if (error != null)
                {
                    failedModelChecks.Add(error);
                }
            }
            return failedModelChecks;
        }

        /// <summary>
        /// Checks that the shape of the Vector Observation input placeholder is the same in the
        /// model and in the Brain Parameters.
        /// </summary>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="tensorProxy">The tensor that is expected by the model</param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        private static string CheckVectorObsShape(
            BrainParameters brainParameters, TensorProxy tensorProxy)
        {
            var vecObsSizeBp = brainParameters.vectorObservationSize;
            var numStackedVector = brainParameters.numStackedVectorObservations;
            // The last dimension of the model's tensor holds the (stacked) observation size.
            var totalVecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];
            if (vecObsSizeBp * numStackedVector != totalVecObsSizeT)
            {
                return "Vector Observation Size of the model does not match. Received " +
                    $"{vecObsSizeBp} x {numStackedVector} but was expecting {totalVecObsSizeT}.";
            }
            return null;
        }

        /// <summary>
        /// Checks that the shape of the Previous Vector Action input placeholder is the same in the
        /// model and in the Brain Parameters.
        /// </summary>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="tensorProxy">The tensor that is expected by the model</param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        private static string CheckPreviousActionShape(
            BrainParameters brainParameters, TensorProxy tensorProxy)
        {
            var numberActionsBp = brainParameters.vectorActionSize.Length;
            var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1];
            if (numberActionsBp != numberActionsT)
            {
                return "Previous Action Size of the model does not match. " +
                    $"Received {numberActionsBp} but was expecting {numberActionsT}.";
            }
            return null;
        }

        /// <summary>
        /// Checks that the shape of the visual observation input placeholder is the same in the
        /// model and in the Brain Parameters.
        /// </summary>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="tensorProxy">
        /// The tensor that is expected by the model, read as [batch, height, width, channels].
        /// </param>
        /// <param name="visObsIndex">The index of the visual observation.</param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        private static string CheckVisualObsShape(
            BrainParameters brainParameters, TensorProxy tensorProxy, int visObsIndex)
        {
            var resolutionBp = brainParameters.cameraResolutions[visObsIndex];
            var widthBp = resolutionBp.width;
            var heightBp = resolutionBp.height;
            var pixelBp = resolutionBp.blackAndWhite ? 1 : 3;
            var heightT = tensorProxy.shape[1];
            var widthT = tensorProxy.shape[2];
            var pixelT = tensorProxy.shape[3];
            if ((widthBp != widthT) || (heightBp != heightT) || (pixelBp != pixelT))
            {
                // Report both shapes in the same [?xHxWxC] order the tensor is indexed in.
                return $"The visual Observation {visObsIndex} of the model does not match. " +
                    $"Received TensorProxy of shape [?x{heightBp}x{widthBp}x{pixelBp}] but " +
                    $"was expecting [?x{heightT}x{widthT}x{pixelT}].";
            }
            return null;
        }

        /// <summary>
        /// Generates failed checks that correspond to output shapes incompatibilities between
        /// the model and the BrainParameters.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="isContinuous">
        /// Whether the model is expecting continuous or discrete control.
        /// </param>
        /// <param name="modelActionSize">
        /// The size of the action output that is expected by the model.
        /// </param>
        /// <returns>
        /// A IEnumerable of string corresponding to the incompatible shapes between model
        /// and BrainParameters.
        /// </returns>
        private static IEnumerable<string> CheckOutputTensorShape(
            Model model,
            BrainParameters brainParameters,
            ModelActionType isContinuous,
            int modelActionSize)
        {
            var failedModelChecks = new List<string>();
            if (isContinuous == ModelActionType.Unknown)
            {
                failedModelChecks.Add("Cannot infer type of Control from the provided model.");
                return failedModelChecks;
            }
            if (isContinuous == ModelActionType.Continuous &&
                brainParameters.vectorActionSpaceType != SpaceType.Continuous)
            {
                failedModelChecks.Add(
                    "Model has been trained using Continuous Control but the Brain Parameters " +
                    "suggest Discrete Control.");
                return failedModelChecks;
            }
            if (isContinuous == ModelActionType.Discrete &&
                brainParameters.vectorActionSpaceType != SpaceType.Discrete)
            {
                failedModelChecks.Add(
                    "Model has been trained using Discrete Control but the Brain Parameters " +
                    "suggest Continuous Control.");
                return failedModelChecks;
            }

            var tensorTester = new Dictionary<string, Func<BrainParameters, TensorShape, int, string>>();
            if (brainParameters.vectorActionSpaceType == SpaceType.Continuous)
            {
                tensorTester[TensorNames.ActionOutput] = CheckContinuousActionOutputShape;
            }
            else
            {
                tensorTester[TensorNames.ActionOutput] = CheckDiscreteActionOutputShape;
            }

            // Only outputs with a registered checker are validated; others are ignored.
            foreach (var name in model.outputs)
            {
                if (tensorTester.TryGetValue(name, out var tester))
                {
                    var error = tester.Invoke(brainParameters, model.GetShapeByName(name), modelActionSize);
                    if (error != null)
                    {
                        failedModelChecks.Add(error);
                    }
                }
            }
            return failedModelChecks;
        }

        /// <summary>
        /// Checks that the shape of the discrete action output is the same in the
        /// model and in the Brain Parameters.
        /// </summary>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="shape">
        /// The tensor shape that is expected by the model (unused; kept so the signature
        /// matches the other output checkers).
        /// </param>
        /// <param name="modelActionSize">
        /// The size of the action output that is expected by the model.
        /// </param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        private static string CheckDiscreteActionOutputShape(
            BrainParameters brainParameters, TensorShape shape, int modelActionSize)
        {
            // For discrete control the model's output size is the total of all branch sizes.
            var bpActionSize = brainParameters.vectorActionSize.Sum();
            if (modelActionSize != bpActionSize)
            {
                return "Action Size of the model does not match. The BrainParameters expect " +
                    $"{bpActionSize} but the model contains {modelActionSize}.";
            }
            return null;
        }

        /// <summary>
        /// Checks that the shape of the continuous action output is the same in the
        /// model and in the Brain Parameters.
        /// </summary>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="shape">
        /// The tensor shape that is expected by the model (unused; kept so the signature
        /// matches the other output checkers).
        /// </param>
        /// <param name="modelActionSize">
        /// The size of the action output that is expected by the model.
        /// </param>
        /// <returns>
        /// If the Check failed (including when the BrainParameters define no action size),
        /// returns a string containing information about why the check failed. If the check
        /// passed, returns null.
        /// </returns>
        private static string CheckContinuousActionOutputShape(
            BrainParameters brainParameters, TensorShape shape, int modelActionSize)
        {
            var vectorActionSize = brainParameters.vectorActionSize;
            // Report a failed check instead of throwing when no action size is configured.
            if (vectorActionSize == null || vectorActionSize.Length == 0)
            {
                return "Action Size of the model does not match. The BrainParameters do not " +
                    $"define an action size but the model contains {modelActionSize}.";
            }

            var bpActionSize = vectorActionSize[0];
            if (modelActionSize != bpActionSize)
            {
                return "Action Size of the model does not match. The BrainParameters expect " +
                    $"{bpActionSize} but the model contains {modelActionSize}.";
            }
            return null;
        }
    }
}