using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Policies;

namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Prepares the Tensors for the Learning Brain and exposes a list of failed checks if Model
    /// and BrainParameters are incompatible.
    /// </summary>
    internal class BarracudaModelParamLoader
    {
        /// <summary>
        /// API version this loader accepts. <see cref="CheckModel"/> rejects any model whose
        /// version tensor holds a different value.
        /// </summary>
        const long k_ApiVersion = 2;

        /// <summary>
        /// Factory for the ModelParamLoader : Creates a ModelParamLoader and runs the checks
        /// on it.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters. May be null, in which case
        /// a single "no model" error is returned.
        /// </param>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="sensors">Attached sensor components</param>
        /// <param name="actuatorComponents">Attached actuator components</param>
        /// <param name="observableAttributeTotalSize">
        /// Sum of the sizes of all ObservableAttributes. Only used when building the vector
        /// observation error message.
        /// </param>
        /// <param name="behaviorType">
        /// BehaviorType of the Agent to check. Only affects the wording of the "no model" error.
        /// </param>
        /// <returns>
        /// The list the error messages of the checks that failed. Empty when every check passed.
        /// </returns>
        public static IEnumerable<string> CheckModel(Model model, BrainParameters brainParameters,
            ISensor[] sensors, ActuatorComponent[] actuatorComponents,
            int observableAttributeTotalSize = 0,
            BehaviorType behaviorType = BehaviorType.Default)
        {
            var failedModelChecks = new List<string>();
            if (model == null)
            {
                var errorMsg = "There is no model for this Brain; cannot run inference. ";
                if (behaviorType == BehaviorType.InferenceOnly)
                {
                    errorMsg += "Either assign a model, or change to a different Behavior Type.";
                }
                else
                {
                    errorMsg += "(But can still train)";
                }
                failedModelChecks.Add(errorMsg);
                return failedModelChecks;
            }

            // Every later check reads tensors by name, so stop here if the basic ones are missing.
            var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks);
            if (!hasExpectedTensors)
            {
                return failedModelChecks;
            }

            var modelApiVersion = (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
            if (modelApiVersion == -1)
            {
                failedModelChecks.Add(
                    "Model was not trained using the right version of ML-Agents. " +
                    "Cannot use this model.");
                return failedModelChecks;
            }
            if (modelApiVersion != k_ApiVersion)
            {
                failedModelChecks.Add(
                    $"Version of the trainer the model was trained with ({modelApiVersion}) " +
                    $"is not compatible with the Brain's version ({k_ApiVersion}).");
                return failedModelChecks;
            }

            var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
            if (memorySize == -1)
            {
                failedModelChecks.Add($"Missing node in the model provided : {TensorNames.MemorySize}");
                return failedModelChecks;
            }

            // From here on the checks are independent; run them all and report every failure.
            failedModelChecks.AddRange(
                CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
            );
            failedModelChecks.AddRange(
                CheckOutputTensorPresence(model, memorySize)
            );
            failedModelChecks.AddRange(
                CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
            );
            failedModelChecks.AddRange(
                CheckOutputTensorShape(model, brainParameters, actuatorComponents)
            );
            return failedModelChecks;
        }

        /// <summary>
        /// Generates failed checks that correspond to inputs expected by the model that are not
        /// present in the BrainParameters.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="memory">
        /// The memory size that the model is expecting.
        /// </param>
        /// <param name="sensors">Array of attached sensor components</param>
        /// <returns>
        /// A IEnumerable of string corresponding to the failed input presence checks.
        /// </returns>
        static IEnumerable<string> CheckInputTensorPresence(
            Model model,
            BrainParameters brainParameters,
            int memory,
            ISensor[] sensors
        )
        {
            var failedModelChecks = new List<string>();
            var tensorsNames = model.GetInputNames();

            // If there is no Vector Observation Input but the Brain Parameters expect one.
            if ((brainParameters.VectorObservationSize != 0) &&
                (!tensorsNames.Contains(TensorNames.VectorObservationPlaceholder)))
            {
                failedModelChecks.Add(
                    "The model does not contain a Vector Observation Placeholder Input. " +
                    "You must set the Vector Observation Space Size to 0.");
            }

            // If there are not enough Visual Observation Input compared to what the
            // sensors expect.
            var visObsIndex = 0;
            for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
            {
                var sensor = sensors[sensorIndex];
                // Query the shape once; a sensor is either rank 3 (visual) or rank 2, never both.
                var rank = sensor.GetObservationShape().Length;
                if (rank == 3)
                {
                    // Visual inputs are numbered by their position among rank-3 sensors only.
                    if (!tensorsNames.Contains(
                        TensorNames.GetVisualObservationName(visObsIndex)))
                    {
                        failedModelChecks.Add(
                            "The model does not contain a Visual Observation Placeholder Input " +
                            $"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
                    }
                    visObsIndex++;
                }
                else if (rank == 2)
                {
                    // Rank-2 inputs are named after the sensor's index in the full sensor array.
                    if (!tensorsNames.Contains(
                        TensorNames.GetObservationName(sensorIndex)))
                    {
                        failedModelChecks.Add(
                            "The model does not contain an Observation Placeholder Input " +
                            $"for sensor component {sensorIndex} ({sensor.GetType().Name}).");
                    }
                }
            }

            var expectedVisualObs = model.GetNumVisualInputs();
            // Check if there's not enough visual sensors (too many would be handled above)
            if (expectedVisualObs > visObsIndex)
            {
                failedModelChecks.Add(
                    $"The model expects {expectedVisualObs} visual inputs," +
                    $" but only found {visObsIndex} visual sensors."
                );
            }

            // If the model has a positive memory size but lacks one of the recurrent inputs.
            // Tensor names are identifiers, so compare them ordinally.
            if (memory > 0)
            {
                if (!tensorsNames.Any(x => x.EndsWith("_h", StringComparison.Ordinal)) ||
                    !tensorsNames.Any(x => x.EndsWith("_c", StringComparison.Ordinal)))
                {
                    failedModelChecks.Add(
                        "The model does not contain a Recurrent Input Node but has memory_size.");
                }
            }

            // If the model uses discrete control but does not have an input for action masks
            if (model.HasDiscreteOutputs())
            {
                if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
                {
                    failedModelChecks.Add(
                        "The model does not contain an Action Mask but is using Discrete Control.");
                }
            }
            return failedModelChecks;
        }

        /// <summary>
        /// Generates failed checks that correspond to outputs expected by the model that are not
        /// present in the BrainParameters.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <param name="memory">The memory size that the model is expecting.</param>
        /// <returns>
        /// A IEnumerable of string corresponding to the failed output presence checks.
        /// </returns>
        static IEnumerable<string> CheckOutputTensorPresence(Model model, int memory)
        {
            var failedModelChecks = new List<string>();

            // If there is no Recurrent Output but the model is Recurrent.
            if (memory > 0)
            {
                var memOutputs = model.memories.Select(x => x.output).ToList();

                // Both an "_h" and a "_c" output must be present; names are compared ordinally.
                if (!memOutputs.Any(x => x.EndsWith("_h", StringComparison.Ordinal)) ||
                    !memOutputs.Any(x => x.EndsWith("_c", StringComparison.Ordinal)))
                {
                    failedModelChecks.Add(
                        "The model does not contain a Recurrent Output Node but has memory_size.");
                }
            }
            return failedModelChecks;
        }

        /// <summary>
        /// Checks that the shape of the visual observation input placeholder is the same as the
        /// corresponding sensor. The sensor shape is read as [height, width, channels].
        /// </summary>
        /// <param name="tensorProxy">The tensor that is expected by the model</param>
        /// <param name="sensor">The sensor that produces the visual observation.</param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        static string CheckVisualObsShape(
            TensorProxy tensorProxy, ISensor sensor)
        {
            var shape = sensor.GetObservationShape();
            var heightBp = shape[0];
            var widthBp = shape[1];
            var pixelBp = shape[2];
            var heightT = tensorProxy.Height;
            var widthT = tensorProxy.Width;
            var pixelT = tensorProxy.Channels;
            if ((widthBp != widthT) || (heightBp != heightT) || (pixelBp != pixelT))
            {
                // The first shape in the message is the sensor's, the second is the model's.
                return "The visual Observation of the model does not match. " +
                    $"Received TensorProxy of shape [?x{widthBp}x{heightBp}x{pixelBp}] but " +
                    $"was expecting [?x{widthT}x{heightT}x{pixelT}].";
            }
            return null;
        }

        /// <summary>
        /// Checks that the shape of the rank 2 observation input placeholder is the same as the
        /// corresponding sensor. The sensor's first dimension is compared with the tensor's
        /// Channels and its second dimension with the tensor's Width.
        /// </summary>
        /// <param name="tensorProxy">The tensor that is expected by the model</param>
        /// <param name="sensor">The sensor that produces the rank 2 observation.</param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        static string CheckRankTwoObsShape(
            TensorProxy tensorProxy, ISensor sensor)
        {
            var shape = sensor.GetObservationShape();
            var dim1Bp = shape[0];
            var dim2Bp = shape[1];
            var dim1T = tensorProxy.Channels;
            var dim2T = tensorProxy.Width;
            if ((dim1Bp != dim1T) || (dim2Bp != dim2T))
            {
                return "An Observation of the model does not match. " +
                    $"Received TensorProxy of shape [?x{dim1Bp}x{dim2Bp}] but " +
                    $"was expecting [?x{dim1T}x{dim2T}].";
            }
            return null;
        }

        /// <summary>
        /// Generates failed checks that correspond to inputs shapes incompatibilities between
        /// the model and the BrainParameters.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="sensors">Attached sensors</param>
        /// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
        /// <returns>The list the error messages of the checks that failed</returns>
        static IEnumerable<string> CheckInputTensorShape(
            Model model, BrainParameters brainParameters, ISensor[] sensors,
            int observableAttributeTotalSize)
        {
            var failedModelChecks = new List<string>();

            // Maps each known input tensor name to the function that validates its shape.
            // Inputs that need no shape validation map to a tester that always returns null.
            var tensorTester =
                new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, string>>()
            {
                {TensorNames.VectorObservationPlaceholder, CheckVectorObsShape},
                {TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
                {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
                {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},
                {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)},
                {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)},
            };

            // Memory inputs are accepted without a shape check.
            foreach (var mem in model.memories)
            {
                tensorTester[mem.input] = ((bp, tensor, scs, i) => null);
            }

            var visObsIndex = 0;
            for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
            {
                // Declared inside the loop so that each lambda captures its own sensor.
                var sens = sensors[sensorIndex];
                var rank = sens.GetObservationShape().Length;
                if (rank == 3)
                {
                    tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] =
                        (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens);
                    visObsIndex++;
                }
                else if (rank == 2)
                {
                    tensorTester[TensorNames.GetObservationName(sensorIndex)] =
                        (bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);
                }
            }

            // If the model expects an input but it is not in this list
            foreach (var tensor in model.GetInputTensors())
            {
                Func<BrainParameters, TensorProxy, ISensor[], int, string> tester;
                if (tensorTester.TryGetValue(tensor.name, out tester))
                {
                    var error = tester.Invoke(brainParameters, tensor, sensors, observableAttributeTotalSize);
                    if (error != null)
                    {
                        failedModelChecks.Add(error);
                    }
                }
                else if (!tensor.name.Contains("visual_observation"))
                {
                    // Unmatched inputs whose name contains "visual_observation" are skipped here
                    // (NOTE(review): presumably because CheckInputTensorPresence already reports
                    // a model that expects more visual inputs than there are visual sensors).
                    failedModelChecks.Add(
                        "Model requires an unknown input named : " + tensor.name);
                }
            }
            return failedModelChecks;
        }

        /// <summary>
        /// Checks that the shape of the Vector Observation input placeholder is the same in the
        /// model and in the Brain Parameters. The last dimension of the tensor is compared with
        /// the summed sizes of all rank 1 sensors.
        /// </summary>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine.
        /// Only read to build the error message.
        /// </param>
        /// <param name="tensorProxy">The tensor that is expected by the model</param>
        /// <param name="sensors">Array of attached sensor components</param>
        /// <param name="observableAttributeTotalSize">
        /// Sum of the sizes of all ObservableAttributes. Only read to build the error message.
        /// </param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        static string CheckVectorObsShape(
            BrainParameters brainParameters, TensorProxy tensorProxy, ISensor[] sensors,
            int observableAttributeTotalSize)
        {
            var vecObsSizeBp = brainParameters.VectorObservationSize;
            var numStackedVector = brainParameters.NumStackedVectorObservations;
            var totalVecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];

            // Collect the size of every rank 1 (vector) sensor in a single pass.
            var vectorSensorSizes = new List<int>();
            foreach (var sens in sensors)
            {
                var shape = sens.GetObservationShape();
                if (shape.Length == 1)
                {
                    vectorSensorSizes.Add(shape[0]);
                }
            }

            var totalVectorSensorSize = vectorSensorSizes.Sum();
            if (totalVectorSensorSize == totalVecObsSizeT)
            {
                return null;
            }

            // Always emit a well-formed list, "[]" included when there are no vector sensors.
            var sensorSizes = "[" + string.Join(", ", vectorSensorSizes) + "]";

            return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
                $"but received: \n" +
                $"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" +
                $"Total [Observable] attributes: {observableAttributeTotalSize}\n" +
                $"Sensor sizes: {sensorSizes}.";
        }

        /// <summary>
        /// Checks that the shape of the Previous Vector Action input placeholder is the same in the
        /// model and in the Brain Parameters. The last dimension of the tensor is compared with
        /// the number of discrete actions in the BrainParameters' ActionSpec.
        /// </summary>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="tensorProxy">The tensor that is expected by the model</param>
        /// <param name="sensors">Array of attached sensor components (unused).</param>
        /// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes (unused).</param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        static string CheckPreviousActionShape(
            BrainParameters brainParameters, TensorProxy tensorProxy,
            ISensor[] sensors, int observableAttributeTotalSize)
        {
            var numberActionsBp = brainParameters.ActionSpec.NumDiscreteActions;
            var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1];
            if (numberActionsBp != numberActionsT)
            {
                return "Previous Action Size of the model does not match. " +
                    $"Received {numberActionsBp} but was expecting {numberActionsT}.";
            }
            return null;
        }

        /// <summary>
        /// Generates failed checks that correspond to output shapes incompatibilities between
        /// the model and the BrainParameters.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters
        /// </param>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="actuatorComponents">Array of attached actuator components.</param>
        /// <returns>
        /// A IEnumerable of string corresponding to the incompatible shapes between model
        /// and BrainParameters.
        /// </returns>
        static IEnumerable<string> CheckOutputTensorShape(
            Model model,
            BrainParameters brainParameters,
            ActuatorComponent[] actuatorComponents)
        {
            var failedModelChecks = new List<string>();

            // Continuous and discrete outputs are checked independently; both may fail.
            var modelContinuousActionSize = model.ContinuousOutputSize();
            var continuousError = CheckContinuousActionOutputShape(
                brainParameters, actuatorComponents, modelContinuousActionSize);
            if (continuousError != null)
            {
                failedModelChecks.Add(continuousError);
            }

            var modelSumDiscreteBranchSizes = model.DiscreteOutputSize();
            var discreteError = CheckDiscreteActionOutputShape(
                brainParameters, actuatorComponents, modelSumDiscreteBranchSizes);
            if (discreteError != null)
            {
                failedModelChecks.Add(discreteError);
            }
            return failedModelChecks;
        }

        /// <summary>
        /// Checks that the shape of the discrete action output is the same in the
        /// model and in the Brain Parameters. The expected size is the sum of discrete branch
        /// sizes of the BrainParameters plus those of every actuator component.
        /// </summary>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="actuatorComponents">Array of attached actuator components.</param>
        /// <param name="modelSumDiscreteBranchSizes">
        /// The size of the discrete action output that is expected by the model.
        /// </param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        static string CheckDiscreteActionOutputShape(
            BrainParameters brainParameters, ActuatorComponent[] actuatorComponents,
            int modelSumDiscreteBranchSizes)
        {
            // TODO: check each branch size instead of sum of branch sizes
            var sumOfDiscreteBranchSizes = brainParameters.ActionSpec.SumOfDiscreteBranchSizes;
            foreach (var actuatorComponent in actuatorComponents)
            {
                var actionSpec = actuatorComponent.ActionSpec;
                sumOfDiscreteBranchSizes += actionSpec.SumOfDiscreteBranchSizes;
            }

            if (modelSumDiscreteBranchSizes != sumOfDiscreteBranchSizes)
            {
                return "Discrete Action Size of the model does not match. The BrainParameters expect " +
                    $"{sumOfDiscreteBranchSizes} but the model contains {modelSumDiscreteBranchSizes}.";
            }
            return null;
        }

        /// <summary>
        /// Checks that the shape of the continuous action output is the same in the
        /// model and in the Brain Parameters. The expected size is the number of continuous
        /// actions of the BrainParameters plus those of every actuator component.
        /// </summary>
        /// <param name="brainParameters">
        /// The BrainParameters that are used verify the compatibility with the InferenceEngine
        /// </param>
        /// <param name="actuatorComponents">Array of attached actuator components.</param>
        /// <param name="modelContinuousActionSize">
        /// The size of the continuous action output that is expected by the model.
        /// </param>
        /// <returns>
        /// If the Check failed, returns a string containing information about why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        static string CheckContinuousActionOutputShape(
            BrainParameters brainParameters, ActuatorComponent[] actuatorComponents,
            int modelContinuousActionSize)
        {
            var numContinuousActions = brainParameters.ActionSpec.NumContinuousActions;
            foreach (var actuatorComponent in actuatorComponents)
            {
                var actionSpec = actuatorComponent.ActionSpec;
                numContinuousActions += actionSpec.NumContinuousActions;
            }

            if (modelContinuousActionSize != numContinuousActions)
            {
                return "Continuous Action Size of the model does not match. The BrainParameters and ActuatorComponents expect " +
                    $"{numContinuousActions} but the model contains {modelContinuousActionSize}.";
            }
            return null;
        }
    }
}