using UnityEngine ;
using Unity.Barracuda ;
using System.IO ;
using Unity.Barracuda.ONNX ;
using Unity.MLAgents ;
using Unity.MLAgents.Policies ;
#if UNITY_EDITOR
{
m_OverrideExtension = args [ i + 1 ] . Trim ( ) . ToLower ( ) ;
var isKnownExtension = k_SupportedExtensions . Contains ( m_OverrideExtension ) ;
// Not supported yet - need to update the model loading code to support
var isOnnx = m_OverrideExtension . Equals ( "onnx" ) ;
if ( ! isKnownExtension | | isOnnx )
if ( ! isKnownExtension )
{
Debug . LogError ( $"loading unsupported format: {m_OverrideExtension}" ) ;
Application . Quit ( 1 ) ;
return null ;
}
byte [ ] m odel = null ;
byte [ ] rawM odel = null ;
m odel = File . ReadAllBytes ( assetPath ) ;
rawM odel = File . ReadAllBytes ( assetPath ) ;
}
catch ( IOException )
{
return null ;
}
// Note - this approach doesn't work for onnx files. Need to replace with
// the equivalent of ONNXModelImporter.OnImportAsset()
var asset = ScriptableObject . CreateInstance < NNModel > ( ) ;
asset . modelData = ScriptableObject . CreateInstance < NNModelData > ( ) ;
asset . modelData . Value = model ;
NNModel asset ;
var isOnnx = m_OverrideExtension . Equals ( "onnx" ) ;
if ( isOnnx )
{
var converter = new ONNXModelConverter ( true ) ;
var onnxModel = converter . Convert ( rawModel ) ;
NNModelData assetData = ScriptableObject . CreateInstance < NNModelData > ( ) ;
using ( var memoryStream = new MemoryStream ( ) )
using ( var writer = new BinaryWriter ( memoryStream ) )
{
ModelWriter . Save ( writer , onnxModel ) ;
assetData . Value = memoryStream . ToArray ( ) ;
}
assetData . name = "Data" ;
assetData . hideFlags = HideFlags . HideInHierarchy ;
asset = ScriptableObject . CreateInstance < NNModel > ( ) ;
asset . modelData = assetData ;
}
else
{
// Note - this approach doesn't work for onnx files. Need to replace with
// the equivalent of ONNXModelImporter.OnImportAsset()
asset = ScriptableObject . CreateInstance < NNModel > ( ) ;
asset . modelData = ScriptableObject . CreateInstance < NNModelData > ( ) ;
asset . modelData . Value = rawModel ;
}
asset . name = "Override - " + Path . GetFileName ( assetPath ) ;
m_CachedModels [ behaviorName ] = asset ;