浏览代码

load onnx files for testing (#4208)

* load onnx files for testing

* changelog

* update interface
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
fcbc47b5
共有 4 个文件被更改,包括 36 次插入16 次删除
  1. 42
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs
  2. 4
      com.unity.ml-agents/CHANGELOG.md
  3. 2
      com.unity.ml-agents/package.json
  4. 4
      ml-agents/tests/yamato/training_int_tests.py

42
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs


using UnityEngine;
using Unity.Barracuda;
using System.IO;
using Unity.Barracuda.ONNX;
using Unity.MLAgents;
using Unity.MLAgents.Policies;
#if UNITY_EDITOR

{
m_OverrideExtension = args[i + 1].Trim().ToLower();
var isKnownExtension = k_SupportedExtensions.Contains(m_OverrideExtension);
// Not supported yet - need to update the model loading code to support
var isOnnx = m_OverrideExtension.Equals("onnx");
if (!isKnownExtension || isOnnx)
if (!isKnownExtension)
{
Debug.LogError($"loading unsupported format: {m_OverrideExtension}");
Application.Quit(1);

return null;
}
byte[] model = null;
byte[] rawModel = null;
model = File.ReadAllBytes(assetPath);
rawModel = File.ReadAllBytes(assetPath);
}
catch (IOException)
{

return null;
}
// Note - this approach doesn't work for onnx files. Need to replace with
// the equivalent of ONNXModelImporter.OnImportAsset()
var asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = ScriptableObject.CreateInstance<NNModelData>();
asset.modelData.Value = model;
NNModel asset;
var isOnnx = m_OverrideExtension.Equals("onnx");
if (isOnnx)
{
var converter = new ONNXModelConverter(true);
var onnxModel = converter.Convert(rawModel);
NNModelData assetData = ScriptableObject.CreateInstance<NNModelData>();
using (var memoryStream = new MemoryStream())
using (var writer = new BinaryWriter(memoryStream))
{
ModelWriter.Save(writer, onnxModel);
assetData.Value = memoryStream.ToArray();
}
assetData.name = "Data";
assetData.hideFlags = HideFlags.HideInHierarchy;
asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = assetData;
}
else
{
// Note - this approach doesn't work for onnx files. Need to replace with
// the equivalent of ONNXModelImporter.OnImportAsset()
asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = ScriptableObject.CreateInstance<NNModelData>();
asset.modelData.Value = rawModel;
}
asset.name = "Override - " + Path.GetFileName(assetPath);
m_CachedModels[behaviorName] = asset;

4
com.unity.ml-agents/CHANGELOG.md


### Minor Changes
#### com.unity.ml-agents (C#)
- Update Barracuda to 1.0.2.
- Enabled C# formatting using `dotnet-format`.
- Update Barracuda to 1.1.0-preview (#4208)
- Enabled C# formatting using `dotnet-format`. (#4362)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Experimental PyTorch support has been added. Use `--torch` when running `mlagents-learn`, or add
`framework: pytorch` to your trainer configuration (under the behavior name) to enable it.

2
com.unity.ml-agents/package.json


"unity": "2018.4",
"description": "Use state-of-the-art machine learning to create intelligent character behaviors in any Unity environment (games, robotics, film, etc.).",
"dependencies": {
"com.unity.barracuda": "1.0.2",
"com.unity.barracuda": "1.1.0-preview",
"com.unity.modules.imageconversion": "1.0.0",
"com.unity.modules.jsonserialize": "1.0.0",
"com.unity.modules.physics": "1.0.0",

4
ml-agents/tests/yamato/training_int_tests.py


if csharp_version is None and python_version is None:
# Use abs path so that loading doesn't get confused
model_path = os.path.abspath(os.path.dirname(nn_file_expected))
# Onnx loading for overrides not currently supported, but this is
# where to add it in when it is.
for extension in ["nn"]:
for extension in ["nn", "onnx"]:
inference_ok = run_inference(env_path, model_path, extension)
if not inference_ok:
return False

正在加载...
取消
保存