浏览代码

Refactor BarracudaModel loader checks (#4629)

* move model methods to BarracudaModelExtensions

* add method to check expected tensors in extensions
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
8669e389
共有 4 个文件被更改,包括 244 次插入168 次删除
  1. 217
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  2. 161
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  3. 4
      com.unity.ml-agents/Runtime/Inference/ModelRunner.cs
  4. 30
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs

217
com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs


using System.Collections.Generic;
using System.Linq;
using Unity.Barracuda;
namespace Unity.MLAgents.Inference

internal static class BarracudaModelExtensions
{
/// <summary>
/// Get array of the input tensor names of the model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>Array of the input tensor names of the model</returns>
public static string[] GetInputNames(this Model model)
{
var names = new List<string>();
if (model == null)
return names.ToArray();
foreach (var input in model.inputs)
{
names.Add(input.name);
}
foreach (var mem in model.memories)
{
names.Add(mem.input);
}
names.Sort();
return names.ToArray();
}
/// <summary>
/// Generates the Tensor inputs that are expected to be present in the Model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>TensorProxy IEnumerable with the expected Tensor inputs.</returns>
public static IReadOnlyList<TensorProxy> GetInputTensors(this Model model)
{
var tensors = new List<TensorProxy>();
if (model == null)
return tensors;
foreach (var input in model.inputs)
{
tensors.Add(new TensorProxy
{
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
foreach (var mem in model.memories)
{
tensors.Add(new TensorProxy
{
name = mem.input,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = TensorUtils.TensorShapeFromBarracuda(mem.shape)
});
}
tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));
return tensors;
}
/// <summary>
/// Get number of visual observation inputs to the model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>Number of visual observation inputs to the model</returns>
public static int GetNumVisualInputs(this Model model)
{
var count = 0;
if (model == null)
return count;
foreach (var input in model.inputs)
{
if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix))
{
count++;
}
}
return count;
}
/// <summary>
/// Get array of the output tensor names of the model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>Array of the output tensor names of the model</returns>
public static string[] GetOutputNames(this Model model)
{
var names = new List<string>();
if (model == null)
{
return names.ToArray();
}
if (model.HasContinuousOutputs())
{
names.Add(model.ContinuousOutputName());
}
if (model.HasDiscreteOutputs())
{
names.Add(model.DiscreteOutputName());
}
var memory = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
if (memory > 0)
{
foreach (var mem in model.memories)
{
names.Add(mem.output);
}
}
names.Sort();
return names.ToArray();
}
/// <summary>
/// Check if the model has continuous action outputs.
/// </summary>
/// <param name="model">

public static bool HasContinuousOutputs(this Model model)
{
if (model == null)
return false;
if (model.UseDeprecated())
{
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;

/// <returns>Size of continuous action output.</returns>
public static int ContinuousOutputSize(this Model model)
{
if (model == null)
return 0;
if (model.UseDeprecated())
{
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?

/// <returns>Tensor name of continuous action output.</returns>
public static string ContinuousOutputName(this Model model)
{
if (model == null)
return null;
if (model.UseDeprecated())
{
return TensorNames.ActionOutputDeprecated;

/// <returns>True if the model has discrete action outputs.</returns>
public static bool HasDiscreteOutputs(this Model model)
{
if (model == null)
return false;
if (model.UseDeprecated())
{
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;

/// <returns>Size of discrete action output.</returns>
public static int DiscreteOutputSize(this Model model)
{
if (model == null)
return 0;
if (model.UseDeprecated())
{
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?

/// <returns>Tensor name of discrete action output.</returns>
public static string DiscreteOutputName(this Model model)
{
if (model == null)
return null;
if (model.UseDeprecated())
{
return TensorNames.ActionOutputDeprecated;

/// <returns>True if the model uses deprecated output fields.</returns>
public static bool UseDeprecated(this Model model)
{
if (model == null)
return false;
}
/// <summary>
/// Check if the model contains all the expected input/output tensors.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>True if the model contains all the expected tensors.</returns>
public static bool CheckExpectedTensors(this Model model, List<string> failedModelChecks)
{
// Check the presence of model version
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
if (modelApiVersionTensor == null)
{
failedModelChecks.Add($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
return false;
}
// Check the presence of memory size
var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
if (memorySizeTensor == null)
{
failedModelChecks.Add($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
return false;
}
// Check the presence of action output tensor
if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
!model.outputs.Contains(TensorNames.DiscreteActionOutput))
{
failedModelChecks.Add("The model does not contain any Action Output Node.");
return false;
}
// Check the presence of action output shape tensor
if (model.UseDeprecated())
{
if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
{
failedModelChecks.Add("The model does not contain any Action Output Shape Node.");
return false;
}
if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
{
failedModelChecks.Add($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was not found in the model file. " +
"This is only required for model that uses a deprecated model format.");
return false;
}
}
else
{
if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
{
failedModelChecks.Add("The model uses continuous action but does not contain Continuous Action Output Shape Node.");
return false;
}
if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
{
failedModelChecks.Add("The model uses discrete action but does not contain Discrete Action Output Shape Node.");
return false;
}
}
return true;
}
}
}

161
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


const long k_ApiVersion = 2;
/// <summary>
/// Generates the Tensor inputs that are expected to be present in the Model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>TensorProxy IEnumerable with the expected Tensor inputs.</returns>
public static IReadOnlyList<TensorProxy> GetInputTensors(Model model)
{
var tensors = new List<TensorProxy>();
if (model == null)
return tensors;
foreach (var input in model.inputs)
{
tensors.Add(new TensorProxy
{
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
foreach (var mem in model.memories)
{
tensors.Add(new TensorProxy
{
name = mem.input,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = TensorUtils.TensorShapeFromBarracuda(mem.shape)
});
}
tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));
return tensors;
}
public static int GetNumVisualInputs(Model model)
{
var count = 0;
if (model == null)
return count;
foreach (var input in model.inputs)
{
if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix))
{
count++;
}
}
return count;
}
/// <summary>
/// Generates the Tensor outputs that are expected to be present in the Model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <returns>TensorProxy IEnumerable with the expected Tensor outputs</returns>
public static string[] GetOutputNames(Model model)
{
var names = new List<string>();
if (model == null)
{
return names.ToArray();
}
if (model.HasContinuousOutputs())
{
names.Add(model.ContinuousOutputName());
}
if (model.HasDiscreteOutputs())
{
names.Add(model.DiscreteOutputName());
}
var memory = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
if (memory > 0)
{
foreach (var mem in model.memories)
{
names.Add(mem.output);
}
}
names.Sort();
return names.ToArray();
}
/// <summary>
/// Factory for the ModelParamLoader : Creates a ModelParamLoader and runs the checks
/// on it.
/// </summary>

return failedModelChecks;
}
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
if (modelApiVersionTensor == null)
var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks);
if (!hasExpectedTensors)
failedModelChecks.Add($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
var modelApiVersion = (int)modelApiVersionTensor[0];
var modelApiVersion = (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
if (modelApiVersion == -1)
{
failedModelChecks.Add(

return failedModelChecks;
}
var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
if (memorySizeTensor == null)
var memorySize = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
if (memorySize == -1)
failedModelChecks.Add($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
failedModelChecks.Add($"Missing node in the model provided : {TensorNames.MemorySize}");
var memorySize = (int)memorySizeTensor[0];
failedModelChecks.AddRange(
CheckIntScalarPresenceHelper(new Dictionary<string, int>()
{
{TensorNames.MemorySize, memorySize},
})
);
failedModelChecks.AddRange(
CheckInputTensorPresence(model, brainParameters, memorySize, sensorComponents)
);

return failedModelChecks;
}
/// <summary>
/// Given a Dictionary of node names to int values, create checks if the values have the
/// invalid value of -1.
/// </summary>
/// <param name="requiredScalarFields"> Mapping from node names to int values</param>
/// <returns>The list the error messages of the checks that failed</returns>
static IEnumerable<string> CheckIntScalarPresenceHelper(
Dictionary<string, int> requiredScalarFields)
{
var failedModelChecks = new List<string>();
foreach (var field in requiredScalarFields)
{
if (field.Value == -1)
{
failedModelChecks.Add($"Missing node in the model provided : {field.Key}");
}
}
return failedModelChecks;
}
/// <summary>
/// Generates failed checks that correspond to inputs expected by the model that are not
/// present in the BrainParameters.

)
{
var failedModelChecks = new List<string>();
var tensorsNames = GetInputTensors(model).Select(x => x.name).ToList();
var tensorsNames = model.GetInputNames();
// If there is no Vector Observation Input but the Brain Parameters expect one.
if ((brainParameters.VectorObservationSize != 0) &&

visObsIndex++;
}
var expectedVisualObs = GetNumVisualInputs(model);
var expectedVisualObs = model.GetNumVisualInputs();
// Check if there's not enough visual sensors (too many would be handled above)
if (expectedVisualObs > visObsIndex)
{

static IEnumerable<string> CheckOutputTensorPresence(Model model, int memory)
{
var failedModelChecks = new List<string>();
// If there is no Action Output.
if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
!model.outputs.Contains(TensorNames.DiscreteActionOutput))
{
failedModelChecks.Add("The model does not contain an Action Output Node.");
}
// If there is no Recurrent Output but the model is Recurrent.
if (memory > 0)

}
// If the model expects an input but it is not in this list
foreach (var tensor in GetInputTensors(model))
foreach (var tensor in model.GetInputTensors())
{
if (!tensorTester.ContainsKey(tensor.name))
{

ActuatorComponent[] actuatorComponents)
{
var failedModelChecks = new List<string>();
// Check the presence of action output shape
if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null &&
model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null &&
model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
{
failedModelChecks.Add("The model does not contain an Action Output Shape Node.");
return failedModelChecks;
}
var tensorTester = new Dictionary<string, Func<BrainParameters, ActuatorComponent[], TensorShape?, int, int, string>>();
if (model.HasContinuousOutputs())

4
com.unity.ml-agents/Runtime/Inference/ModelRunner.cs


m_Engine = null;
}
m_InferenceInputs = BarracudaModelParamLoader.GetInputTensors(barracudaModel);
m_OutputNames = BarracudaModelParamLoader.GetOutputNames(barracudaModel);
m_InferenceInputs = barracudaModel.GetInputTensors();
m_OutputNames = barracudaModel.GetOutputNames();
m_TensorGenerator = new TensorGenerator(
seed, m_TensorAllocator, m_Memories, barracudaModel);
m_TensorApplier = new TensorApplier(

30
com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs


public void TestGetInputTensorsContinuous(bool useDeprecatedNNModel)
{
var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
var inputTensors = BarracudaModelParamLoader.GetInputTensors(model);
var inputNames = inputTensors.Select(x => x.name).ToList();
var inputNames = model.GetInputNames();
Assert.AreEqual(3, inputNames.Count);
Assert.AreEqual(3, inputNames.Count());
Assert.AreEqual(2, BarracudaModelParamLoader.GetNumVisualInputs(model));
Assert.AreEqual(2, model.GetNumVisualInputs());
Assert.AreEqual(0, BarracudaModelParamLoader.GetInputTensors(null).Count);
Assert.AreEqual(0, BarracudaModelParamLoader.GetNumVisualInputs(null));
model = null;
Assert.AreEqual(0, model.GetInputTensors().Count);
Assert.AreEqual(0, model.GetNumVisualInputs());
}
[TestCase(true)]

var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
var inputTensors = BarracudaModelParamLoader.GetInputTensors(model);
var inputNames = inputTensors.Select(x => x.name).ToList();
var inputNames = model.GetInputNames();
// Model should contain 2 inputs : recurrent and visual 1
Assert.Contains(TensorNames.VisualObservationPlaceholderPrefix + "0", inputNames);

public void TestGetInputTensorsHybrid()
{
var model = ModelLoader.Load(hybridONNXModel);
var inputTensors = BarracudaModelParamLoader.GetInputTensors(model);
var inputNames = inputTensors.Select(x => x.name).ToList();
var inputNames = model.GetInputNames();
Assert.Contains(TensorNames.VectorObservationPlaceholder, inputNames);
}

{
var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
var outputNames = BarracudaModelParamLoader.GetOutputNames(model);
var outputNames = model.GetOutputNames();
Assert.AreEqual(0, BarracudaModelParamLoader.GetOutputNames(null).Count());
model = null;
Assert.AreEqual(0, model.GetOutputNames().Count());
}
[TestCase(true)]

var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
var outputNames = BarracudaModelParamLoader.GetOutputNames(model);
var outputNames = model.GetOutputNames();
var actionOutputName = useDeprecatedNNModel ? TensorNames.ActionOutputDeprecated : TensorNames.DiscreteActionOutput;
Assert.Contains(actionOutputName, outputNames);
// TODO : There are some memory tensors as well

public void TestGetOutputTensorsHybrid()
{
var model = ModelLoader.Load(hybridONNXModel);
var outputNames = BarracudaModelParamLoader.GetOutputNames(model);
var outputNames = model.GetOutputNames();
Assert.AreEqual(0, BarracudaModelParamLoader.GetOutputNames(null).Count());
model = null;
Assert.AreEqual(0, model.GetOutputNames().Count());
}
[TestCase(true)]

正在加载...
取消
保存