using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Barracuda;
using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck;
namespace Unity.MLAgents.Inference
{
///
/// Barracuda Model extension methods.
///
internal static class BarracudaModelExtensions
{
    /// <summary>
    /// Get array of the input tensor names of the model.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// Array of the input tensor names of the model. The names of the model's declared
    /// inputs and the input names of its memories are combined and sorted with
    /// <see cref="StringComparer.InvariantCulture"/>. An empty array is returned when
    /// <paramref name="model"/> is null.
    /// </returns>
    public static string[] GetInputNames(this Model model)
    {
        var names = new List<string>();
        if (model == null)
        {
            return names.ToArray();
        }

        foreach (var input in model.inputs)
        {
            names.Add(input.name);
        }

        // Each memory also contributes an input tensor name.
        foreach (var mem in model.memories)
        {
            names.Add(mem.input);
        }

        // Sorting gives a deterministic order independent of declaration order.
        names.Sort(StringComparer.InvariantCulture);
        return names.ToArray();
    }

    /// <summary>
    /// Get the version of the model.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>The api version of the model</returns>
    /// <remarks>
    /// No null checks are performed: a null <paramref name="model"/> or a model without the
    /// version constant results in a <see cref="NullReferenceException"/>. Use
    /// <see cref="CheckExpectedTensors"/> to verify the constant is present.
    /// </remarks>
    public static int GetVersion(this Model model)
    {
        return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
    }

    /// <summary>
    /// Generates the Tensor inputs that are expected to be present in the Model.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// TensorProxy list with the expected Tensor inputs (declared inputs plus memory
    /// inputs), all floating point with no data, sorted by name with the invariant culture.
    /// An empty list is returned when <paramref name="model"/> is null.
    /// </returns>
    public static IReadOnlyList<TensorProxy> GetInputTensors(this Model model)
    {
        var tensors = new List<TensorProxy>();
        if (model == null)
        {
            return tensors;
        }

        foreach (var input in model.inputs)
        {
            tensors.Add(new TensorProxy
            {
                name = input.name,
                valueType = TensorProxy.TensorType.FloatingPoint,
                data = null,
                // Widen each dimension to long to match TensorProxy.shape.
                shape = input.shape.Select(i => (long)i).ToArray()
            });
        }

        foreach (var mem in model.memories)
        {
            tensors.Add(new TensorProxy
            {
                name = mem.input,
                valueType = TensorProxy.TensorType.FloatingPoint,
                data = null,
                shape = TensorUtils.TensorShapeFromBarracuda(mem.shape)
            });
        }

        // Same ordering as GetInputNames so names and tensors line up.
        tensors.Sort((el1, el2) => string.Compare(el1.name, el2.name, StringComparison.InvariantCulture));
        return tensors;
    }

    /// <summary>
    /// Get number of visual observation inputs to the model.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// Number of model inputs whose name starts with
    /// <c>TensorNames.VisualObservationPlaceholderPrefix</c>; 0 when
    /// <paramref name="model"/> is null.
    /// </returns>
    public static int GetNumVisualInputs(this Model model)
    {
        var count = 0;
        if (model == null)
        {
            return count;
        }

        foreach (var input in model.inputs)
        {
            // Tensor names are identifiers, so compare ordinally rather than by culture.
            if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix, StringComparison.Ordinal))
            {
                count++;
            }
        }

        return count;
    }

    /// <summary>
    /// Get array of the output tensor names of the model.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// Array of the output tensor names of the model: the continuous and/or discrete action
    /// output names, plus the memory output names when the memory size constant is greater
    /// than 0, sorted with <see cref="StringComparer.InvariantCulture"/>. An empty array is
    /// returned when <paramref name="model"/> is null.
    /// </returns>
    /// <remarks>
    /// The memory size constant is read without a null check; a model lacking it results in
    /// a <see cref="NullReferenceException"/>.
    /// </remarks>
    public static string[] GetOutputNames(this Model model)
    {
        var names = new List<string>();
        if (model == null)
        {
            return names.ToArray();
        }

        if (model.HasContinuousOutputs())
        {
            names.Add(model.ContinuousOutputName());
        }

        if (model.HasDiscreteOutputs())
        {
            names.Add(model.DiscreteOutputName());
        }

        var memory = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
        if (memory > 0)
        {
            foreach (var mem in model.memories)
            {
                names.Add(mem.output);
            }
        }

        names.Sort(StringComparer.InvariantCulture);
        return names.ToArray();
    }

    /// <summary>
    /// Check if the model has continuous action outputs.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// True if the model has continuous action outputs. Returns false when
    /// <paramref name="model"/> is null, or when a new-format model does not expose the
    /// continuous action output or its shape constant is missing or not positive.
    /// </returns>
    public static bool HasContinuousOutputs(this Model model)
    {
        if (model == null)
        {
            return false;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            // Deprecated format: a single flag constant selects continuous vs discrete.
            return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;
        }

        // Delegate to ContinuousOutputSize (which tolerates a missing shape constant),
        // mirroring how HasDiscreteOutputs delegates to DiscreteOutputSize.
        return model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
            model.ContinuousOutputSize() > 0;
    }

    /// <summary>
    /// Continuous action output size of the model.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// Size of continuous action output; 0 when <paramref name="model"/> is null, when a
    /// deprecated-format model is discrete, or when the continuous shape constant is missing.
    /// </returns>
    public static int ContinuousOutputSize(this Model model)
    {
        if (model == null)
        {
            return 0;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0] : 0;
        }

        var continuousOutputShape = model.GetTensorByName(TensorNames.ContinuousActionOutputShape);
        return continuousOutputShape == null ? 0 : (int)continuousOutputShape[0];
    }

    /// <summary>
    /// Continuous action output tensor name of the model.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// Tensor name of continuous action output: the deprecated shared action output name
    /// for old-format models, otherwise the dedicated continuous output name; null when
    /// <paramref name="model"/> is null.
    /// </returns>
    public static string ContinuousOutputName(this Model model)
    {
        if (model == null)
        {
            return null;
        }

        return model.SupportsContinuousAndDiscrete()
            ? TensorNames.ContinuousActionOutput
            : TensorNames.ActionOutputDeprecated;
    }

    /// <summary>
    /// Check if the model has discrete action outputs.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// True if the model has discrete action outputs; false when <paramref name="model"/>
    /// is null.
    /// </returns>
    public static bool HasDiscreteOutputs(this Model model)
    {
        if (model == null)
        {
            return false;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            // Deprecated format: discrete when the continuous flag constant is 0.
            return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;
        }

        return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
            model.DiscreteOutputSize() > 0;
    }

    /// <summary>
    /// Discrete action output size of the model. This is equal to the sum of the branch sizes.
    /// This method gets the tensor representing the list of branch size and returns the
    /// sum of all the elements in the Tensor.
    /// - In version 1.X this tensor contains a single number, the sum of all branch
    /// size values.
    /// - In version 2.X this tensor contains a 1D Tensor with each element corresponding
    /// to a branch size.
    /// Since this method does the sum of all elements in the tensor, the output
    /// will be the same on both 1.X and 2.X.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// Size of discrete action output; 0 when <paramref name="model"/> is null, when a
    /// deprecated-format model is continuous, or when the discrete shape constant is missing.
    /// </returns>
    public static int DiscreteOutputSize(this Model model)
    {
        if (model == null)
        {
            return 0;
        }

        if (!model.SupportsContinuousAndDiscrete())
        {
            return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                0 : (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0];
        }

        var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
        if (discreteOutputShape == null)
        {
            return 0;
        }

        var result = 0;
        for (var i = 0; i < discreteOutputShape.length; i++)
        {
            result += (int)discreteOutputShape[i];
        }

        return result;
    }

    /// <summary>
    /// Discrete action output tensor name of the model.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// Tensor name of discrete action output: the deprecated shared action output name
    /// for old-format models, otherwise the dedicated discrete output name; null when
    /// <paramref name="model"/> is null.
    /// </returns>
    public static string DiscreteOutputName(this Model model)
    {
        if (model == null)
        {
            return null;
        }

        return model.SupportsContinuousAndDiscrete()
            ? TensorNames.DiscreteActionOutput
            : TensorNames.ActionOutputDeprecated;
    }

    /// <summary>
    /// Check if the model supports both continuous and discrete actions.
    /// If not, the model should be handled differently and use the deprecated fields.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters.
    /// </param>
    /// <returns>
    /// True if the model exposes the continuous or the discrete action output. Also returns
    /// true for a null <paramref name="model"/>.
    /// </returns>
    public static bool SupportsContinuousAndDiscrete(this Model model)
    {
        return model == null ||
            model.outputs.Contains(TensorNames.ContinuousActionOutput) ||
            model.outputs.Contains(TensorNames.DiscreteActionOutput);
    }

    /// <summary>
    /// Check if the model contains all the expected input/output tensors.
    /// Checking stops at the first problem found; each problem is recorded as a warning.
    /// </summary>
    /// <param name="model">
    /// The Barracuda engine model for loading static parameters. Must not be null.
    /// </param>
    /// <param name="failedModelChecks">
    /// Output list of failure messages; at most one entry is appended per call.
    /// </param>
    /// <returns>True if the model contains all the expected tensors.</returns>
    public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
    {
        // Check the presence of model version
        var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
        if (modelApiVersionTensor == null)
        {
            failedModelChecks.Add(
                FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
            );
            return false;
        }

        // Check the presence of memory size
        var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
        if (memorySizeTensor == null)
        {
            failedModelChecks.Add(
                FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
            );
            return false;
        }

        // Check the presence of action output tensor (deprecated or new format)
        if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
            !model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
            !model.outputs.Contains(TensorNames.DiscreteActionOutput))
        {
            failedModelChecks.Add(
                FailedCheck.Warning("The model does not contain any Action Output Node.")
            );
            return false;
        }

        // Check the presence of action output shape tensor
        if (!model.SupportsContinuousAndDiscrete())
        {
            if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
                );
                return false;
            }

            if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
                        "not found in the model file. " +
                        "This is only required for model that uses a deprecated model format.")
                );
                return false;
            }
        }
        else
        {
            if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
                model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
                );
                return false;
            }

            if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
                model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
                );
                return false;
            }
        }

        return true;
    }
}
}