using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Barracuda;

namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Barracuda Model extension methods.
    /// </summary>
    internal static class BarracudaModelExtensions
    {
        /// <summary>
        /// Get array of the input tensor names of the model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// Sorted array of the names in <c>model.inputs</c> together with the input names of
        /// <c>model.memories</c>; an empty array if <paramref name="model"/> is null.
        /// </returns>
        public static string[] GetInputNames(this Model model)
        {
            if (model == null)
            {
                return Array.Empty<string>();
            }

            var names = new List<string>();
            foreach (var input in model.inputs)
            {
                names.Add(input.name);
            }

            // Memory input names are reported alongside the regular model inputs.
            foreach (var mem in model.memories)
            {
                names.Add(mem.input);
            }

            names.Sort();
            return names.ToArray();
        }

        /// <summary>
        /// Generates the Tensor inputs that are expected to be present in the Model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// One floating-point <see cref="TensorProxy"/> (with no data) per model input and per
        /// memory input, sorted by name; an empty list if <paramref name="model"/> is null.
        /// </returns>
        public static IReadOnlyList<TensorProxy> GetInputTensors(this Model model)
        {
            var tensors = new List<TensorProxy>();
            if (model == null)
            {
                return tensors;
            }

            foreach (var input in model.inputs)
            {
                tensors.Add(new TensorProxy
                {
                    name = input.name,
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = null,
                    shape = input.shape.Select(i => (long)i).ToArray()
                });
            }

            foreach (var mem in model.memories)
            {
                tensors.Add(new TensorProxy
                {
                    name = mem.input,
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = null,
                    shape = TensorUtils.TensorShapeFromBarracuda(mem.shape)
                });
            }

            // Default string comparison, the same ordering GetInputNames uses for its names.
            tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));
            return tensors;
        }

        /// <summary>
        /// Get number of visual observation inputs to the model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// Number of inputs whose name starts with
        /// <c>TensorNames.VisualObservationPlaceholderPrefix</c>; 0 if <paramref name="model"/> is null.
        /// </returns>
        public static int GetNumVisualInputs(this Model model)
        {
            if (model == null)
            {
                return 0;
            }

            // Tensor names are identifiers, not linguistic text, so compare ordinally.
            return model.inputs.Count(
                input => input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix, StringComparison.Ordinal));
        }

        /// <summary>
        /// Get array of the output tensor names of the model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// Sorted array containing the continuous and/or discrete action output names (when present)
        /// and, if the model declares a positive memory size, the memory output names.
        /// Empty if <paramref name="model"/> is null.
        /// </returns>
        public static string[] GetOutputNames(this Model model)
        {
            if (model == null)
            {
                return Array.Empty<string>();
            }

            var names = new List<string>();
            if (model.HasContinuousOutputs())
            {
                names.Add(model.ContinuousOutputName());
            }

            if (model.HasDiscreteOutputs())
            {
                names.Add(model.DiscreteOutputName());
            }

            // A missing memory-size constant is treated as "no memory" rather than
            // dereferencing a null tensor.
            var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
            var memory = memorySizeTensor == null ? 0 : (int)memorySizeTensor[0];
            if (memory > 0)
            {
                foreach (var mem in model.memories)
                {
                    names.Add(mem.output);
                }
            }

            names.Sort();
            return names.ToArray();
        }

        /// <summary>
        /// Check if the model has continuous action outputs.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>True if the model has continuous action outputs.</returns>
        public static bool HasContinuousOutputs(this Model model)
        {
            if (model == null)
            {
                return false;
            }

            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;
            }

            // ContinuousOutputSize yields 0 when the shape constant is missing, so a malformed
            // model reports "no continuous outputs" instead of throwing.
            return model.outputs.Contains(TensorNames.ContinuousActionOutput)
                && model.ContinuousOutputSize() > 0;
        }

        /// <summary>
        /// Continuous action output size of the model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// Size of continuous action output; 0 if the model is null, is a deprecated-format
        /// discrete model, or lacks the continuous output shape constant.
        /// </returns>
        public static int ContinuousOutputSize(this Model model)
        {
            if (model == null)
            {
                return 0;
            }

            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0
                    ? (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0]
                    : 0;
            }

            var continuousOutputShape = model.GetTensorByName(TensorNames.ContinuousActionOutputShape);
            return continuousOutputShape == null ? 0 : (int)continuousOutputShape[0];
        }

        /// <summary>
        /// Continuous action output tensor name of the model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// Tensor name of continuous action output (the shared deprecated action output name for
        /// deprecated-format models); null if <paramref name="model"/> is null.
        /// </returns>
        public static string ContinuousOutputName(this Model model)
        {
            if (model == null)
            {
                return null;
            }

            return model.SupportsContinuousAndDiscrete()
                ? TensorNames.ContinuousActionOutput
                : TensorNames.ActionOutputDeprecated;
        }

        /// <summary>
        /// Check if the model has discrete action outputs.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>True if the model has discrete action outputs.</returns>
        public static bool HasDiscreteOutputs(this Model model)
        {
            if (model == null)
            {
                return false;
            }

            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;
            }

            // DiscreteOutputSize yields 0 when the shape constant is missing, so a malformed
            // model reports "no discrete outputs" instead of throwing.
            return model.outputs.Contains(TensorNames.DiscreteActionOutput)
                && model.DiscreteOutputSize() > 0;
        }

        /// <summary>
        /// Discrete action output size of the model. This is equal to the sum of the branch sizes.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// Size of discrete action output; 0 if the model is null, is a deprecated-format
        /// continuous model, or lacks the discrete output shape constant.
        /// </returns>
        public static int DiscreteOutputSize(this Model model)
        {
            if (model == null)
            {
                return 0;
            }

            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0
                    ? 0
                    : (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0];
            }

            var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
            return discreteOutputShape == null ? 0 : (int)discreteOutputShape[0];
        }

        /// <summary>
        /// Discrete action output tensor name of the model.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// Tensor name of discrete action output (the shared deprecated action output name for
        /// deprecated-format models); null if <paramref name="model"/> is null.
        /// </returns>
        public static string DiscreteOutputName(this Model model)
        {
            if (model == null)
            {
                return null;
            }

            return model.SupportsContinuousAndDiscrete()
                ? TensorNames.DiscreteActionOutput
                : TensorNames.ActionOutputDeprecated;
        }

        /// <summary>
        /// Check if the model supports both continuous and discrete actions.
        /// If not, the model should be handled differently and use the deprecated fields.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <returns>
        /// True if the model outputs either the continuous or the discrete action tensor.
        /// Also true for a null model.
        /// </returns>
        public static bool SupportsContinuousAndDiscrete(this Model model)
        {
            return model == null ||
                model.outputs.Contains(TensorNames.ContinuousActionOutput) ||
                model.outputs.Contains(TensorNames.DiscreteActionOutput);
        }

        /// <summary>
        /// Check if the model contains all the expected input/output tensors.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <param name="failedModelChecks">
        /// List to which a message describing the first failed check is appended.
        /// </param>
        /// <returns>True if the model contains all the expected tensors.</returns>
        public static bool CheckExpectedTensors(this Model model, List<string> failedModelChecks)
        {
            // Every other extension here tolerates a null model; report it instead of throwing.
            if (model == null)
            {
                failedModelChecks.Add("The model is null.");
                return false;
            }

            // Check the presence of model version
            var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
            if (modelApiVersionTensor == null)
            {
                failedModelChecks.Add($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
                return false;
            }

            // Check the presence of memory size
            var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
            if (memorySizeTensor == null)
            {
                failedModelChecks.Add($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
                return false;
            }

            // Check the presence of action output tensor
            if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
                !model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
                !model.outputs.Contains(TensorNames.DiscreteActionOutput))
            {
                failedModelChecks.Add("The model does not contain any Action Output Node.");
                return false;
            }

            // Check the presence of action output shape tensor
            if (!model.SupportsContinuousAndDiscrete())
            {
                if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
                {
                    failedModelChecks.Add("The model does not contain any Action Output Shape Node.");
                    return false;
                }
                if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
                {
                    failedModelChecks.Add($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was not found in the model file. " +
                        "This is only required for model that uses a deprecated model format.");
                    return false;
                }
            }
            else
            {
                if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
                    model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
                {
                    failedModelChecks.Add("The model uses continuous action but does not contain Continuous Action Output Shape Node.");
                    return false;
                }
                if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
                    model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
                {
                    failedModelChecks.Add("The model uses discrete action but does not contain Discrete Action Output Shape Node.");
                    return false;
                }
            }
            return true;
        }
    }
}