public class ModelOverrider : MonoBehaviour
{
HashSet < string > k_SupportedExtensions = new HashSet < string > { "nn" , "onnx" } ;
const string k_CommandLineModelOverrideFlag = "--mlagents-override-model" ;
const string k_CommandLineModelOverrideDirectoryFlag = "--mlagents-override-model-directory" ;
const string k_CommandLineModelOverrideExtensionFlag = "--mlagents-override-model-extension" ;
const string k_CommandLineQuitAfterEpisodesFlag = "--mlagents-quit-after-episodes" ;
Agent m_Agent ;
// Assets paths to use, with the behavior name as the key.
Dictionary < string , string > m_BehaviorNameOverrides = new Dictionary < string , string > ( ) ;
string m_OverrideExtension = "nn" ;
private List < string > m_OverrideExtensions = new List < string > ( ) ;
// Cached loaded NNModels, with the behavior name as the key.
Dictionary < string , NNModel > m_CachedModels = new Dictionary < string , NNModel > ( ) ;
public bool HasOverrides
{
get { return m_BehaviorNameOverrides . Count > 0 | | ! string . IsNullOrEmpty ( m_BehaviorNameOverrideDirectory ) ; }
get
{
GetAssetPathFromCommandLine ( ) ;
return ! string . IsNullOrEmpty ( m_BehaviorNameOverrideDirectory ) ;
}
}
public static string GetOverrideBehaviorName ( string originalBehaviorName )
/// <returns></returns>
void GetAssetPathFromCommandLine ( )
{
m_BehaviorNameOverrides . Clear ( ) ;
var maxEpisodes = 0 ;
string [ ] commandLineArgsOverride = null ;
if ( ! string . IsNullOrEmpty ( debugCommandLineOverride ) & & Application . isEditor )
var args = commandLineArgsOverride ? ? Environment . GetCommandLineArgs ( ) ;
for ( var i = 0 ; i < args . Length ; i + + )
{
if ( args [ i ] = = k_CommandLineModelOverrideFlag & & i < args . Length - 2 )
{
var key = args [ i + 1 ] . Trim ( ) ;
var value = args [ i + 2 ] . Trim ( ) ;
m_BehaviorNameOverrides [ key ] = value ;
}
else if ( args [ i ] = = k_CommandLineModelOverrideDirectoryFlag & & i < args . Length - 1 )
if ( args [ i ] = = k_CommandLineModelOverrideDirectoryFlag & & i < args . Length - 1 )
m_OverrideExtension = args [ i + 1 ] . Trim ( ) . ToLower ( ) ;
var isKnownExtension = k_SupportedExtensions . Contains ( m_OverrideExtension ) ;
var overrideExtension = args [ i + 1 ] . Trim ( ) . ToLower ( ) ;
var isKnownExtension = k_SupportedExtensions . Contains ( overrideExtension ) ;
Debug . LogError ( $"loading unsupported format: {m_OverrideExtension}" ) ;
Debug . LogError ( $"loading unsupported format: {overrideExtension}" ) ;
m_OverrideExtensions . Add ( overrideExtension ) ;
}
else if ( args [ i ] = = k_CommandLineQuitAfterEpisodesFlag & & i < args . Length - 1 )
{
}
}
if ( HasOverrides )
if ( ! string . IsNullOrEmpty ( m_BehaviorNameOverrideDirectory ) )
{
// If overriding models, set maxEpisodes to 1 or the command line value
m_MaxEpisodes = maxEpisodes > 0 ? maxEpisodes : 1 ;
return m_CachedModels [ behaviorName ] ;
}
string assetPath = null ;
if ( m_BehaviorNameOverrides . ContainsKey ( behaviorName ) )
if ( string . IsNullOrEmpty ( m_BehaviorNameOverrideDirectory ) )
assetPath = m_BehaviorNameOverrides [ behaviorName ] ;
}
else if ( ! string . IsNullOrEmpty ( m_BehaviorNameOverrideDirectory ) )
{
assetPath = Path . Combine ( m_BehaviorNameOverrideDirectory , $"{behaviorName}.{m_OverrideExtension}" ) ;
Debug . Log ( $"No override directory set." ) ;
return null ;
if ( string . IsNullOrEmpty ( assetPath ) )
{
Debug . Log ( $"No override for BehaviorName {behaviorName}, and no directory set." ) ;
return null ;
}
// Try the override extensions in order. If they weren't set, try .nn first, then .onnx.
var overrideExtensions = ( m_OverrideExtensions . Count > 0 )
? m_OverrideExtensions . ToArray ( )
: new [ ] { "nn" , "onnx" } ;
try
bool isOnnx = false ;
string assetName = null ;
foreach ( var overrideExtension in overrideExtensions )
rawModel = File . ReadAllBytes ( assetPath ) ;
var assetPath = Path . Combine ( m_BehaviorNameOverrideDirectory , $"{behaviorName}.{overrideExtension}" ) ;
try
{
rawModel = File . ReadAllBytes ( assetPath ) ;
isOnnx = overrideExtension . Equals ( "onnx" ) ;
assetName = "Override - " + Path . GetFileName ( assetPath ) ;
break ;
}
catch ( IOException )
{
// Do nothing - try the next extension, or we'll exit if nothing loaded.
}
catch ( IOException )
if ( rawModel = = null )
Debug . Log ( $"Couldn't load file {assetPath} at full path {Path.GetFullPath(assetPath)}" , this ) ;
Debug . Log ( $"Couldn't load model file(s) for {behaviorName} in {m_BehaviorNameOverrideDirectory} (full path: {Path.GetFullPath(m_BehaviorNameOverrideDirectory)}" ) ;
NNModel asset ;
var isOnnx = m_OverrideExtension . Equals ( "onnx" ) ;
if ( isOnnx )
{
var converter = new ONNXModelConverter ( true ) ;
var onnxModel = converter . Convert ( rawModel ) ;
var asset = isOnnx ? LoadOnnxModel ( rawModel ) : LoadBarracudaModel ( rawModel ) ;
asset . name = assetName ;
m_CachedModels [ behaviorName ] = asset ;
return asset ;
}
NNModelData assetData = ScriptableObject . CreateInstance < NNModelData > ( ) ;
using ( var memoryStream = new MemoryStream ( ) )
using ( var writer = new BinaryWriter ( memoryStream ) )
{
ModelWriter . Save ( writer , onnxModel ) ;
assetData . Value = memoryStream . ToArray ( ) ;
}
assetData . name = "Data" ;
assetData . hideFlags = HideFlags . HideInHierarchy ;
NNModel LoadBarracudaModel ( byte [ ] rawModel )
{
var asset = ScriptableObject . CreateInstance < NNModel > ( ) ;
asset . modelData = ScriptableObject . CreateInstance < NNModelData > ( ) ;
asset . modelData . Value = rawModel ;
return asset ;
}
asset = ScriptableObject . CreateInstance < NNModel > ( ) ;
asset . modelData = assetData ;
}
else
NNModel LoadOnnxModel ( byte [ ] rawModel )
{
var converter = new ONNXModelConverter ( true ) ;
var onnxModel = converter . Convert ( rawModel ) ;
NNModelData assetData = ScriptableObject . CreateInstance < NNModelData > ( ) ;
using ( var memoryStream = new MemoryStream ( ) )
using ( var writer = new BinaryWriter ( memoryStream ) )
// Note - this approach doesn't work for onnx files. Need to replace with
// the equivalent of ONNXModelImporter.OnImportAsset()
asset = ScriptableObject . CreateInstance < NNModel > ( ) ;
asset . modelData = ScriptableObject . CreateInstance < NNModelData > ( ) ;
asset . modelData . Value = rawModel ;
ModelWriter . Save ( writer , onnxModel ) ;
assetData . Value = memoryStream . ToArray ( ) ;
assetData . name = "Data" ;
assetData . hideFlags = HideFlags . HideInHierarchy ;
asset . name = "Override - " + Path . GetFileName ( assetPath ) ;
m_CachedModels [ behaviorName ] = asset ;
var asset = ScriptableObject . CreateInstance < NNModel > ( ) ;
asset . modelData = assetData ;
/// <summary>
/// Load the NNModel file from the specified path, and give it to the attached agent.