using Unity.Barracuda.ONNX ;
using Unity.MLAgents ;
using Unity.MLAgents.Policies ;
using UnityEngine.SceneManagement ;
#if UNITY_EDITOR
using UnityEditor ;
#endif
/// </summary>
public class ModelOverrider : MonoBehaviour
{
HashSet < string > k_SupportedExtensions = new HashSet < string > { "nn" , "onnx" } ;
static HashSet < string > k_SupportedExtensions = new HashSet < string > { "nn" , "onnx" } ;
const string k_CommandLineModelOverrideDirectoryFlag = "--mlagents-override-model-directory" ;
const string k_CommandLineModelOverrideExtensionFlag = "--mlagents-override-model-extension" ;
const string k_CommandLineQuitAfterEpisodesFlag = "--mlagents-quit-after-episodes" ;
public struct OverrideSettings
{
public string BehaviorNameOverrideDirectory ;
public List < string > OverrideExtensions ;
public int MaxEpisodes ;
public int TimeoutSeconds ;
public bool QuitOnLoadFailure ;
public bool HasOverrides ( )
{
return ! string . IsNullOrEmpty ( BehaviorNameOverrideDirectory ) ;
}
}
private OverrideSettings m_OverrideSettings ;
string m_BehaviorNameOverrideDirectory ;
//string m_BehaviorNameOverrideDirectory;
private List < string > m_OverrideExtensions = new List < string > ( ) ;
//private List<string> m_OverrideExtensions = new List<string>();
// Cached loaded NNModels, with the behavior name as the key.
Dictionary < string , NNModel > m_CachedModels = new Dictionary < string , NNModel > ( ) ;
// Will default to 1 if override models are specified, otherwise 0.
int m_MaxEpisodes ;
//int m_MaxEpisodes;
// Deadline - exit if the time exceeds this
DateTime m_Deadline = DateTime . MaxValue ;
int m_PreviousAgentCompletedEpisodes ;
bool m_QuitOnLoadFailure ;
//bool m_QuitOnLoadFailure;
[Tooltip("Debug values to be used in place of the command line for overriding models.")]
public string debugCommandLineOverride ;
get
{
GetAssetPathFromCommandLine ( ) ;
return ! string . IsNullOrEmpty ( m_BehaviorNameOverrideDirectory ) ;
return m_OverrideSettings . HasOverrides ( ) ;
}
}
}
/// <summary>
/// Get the asset path to use from the commandline arguments.
/// </summary>
/// <returns></returns>
void GetAssetPathFromCommandLine ( )
static OverrideSettings GetSettingsFromCommandLine ( string debugCommandLineOverride = null )
var maxEpisodes = 0 ;
var timeoutSeconds = 0 ;
var overrideSettings = new OverrideSettings ( ) ;
overrideSettings . OverrideExtensions = new List < string > ( ) ;
string [ ] commandLineArgsOverride = null ;
if ( ! string . IsNullOrEmpty ( debugCommandLineOverride ) & & Application . isEditor )
{
if ( args [ i ] = = k_CommandLineModelOverrideDirectoryFlag & & i < args . Length - 1 )
{
m_ BehaviorNameOverrideDirectory = args [ i + 1 ] . Trim ( ) ;
overrideSettings . BehaviorNameOverrideDirectory = args [ i + 1 ] . Trim ( ) ;
}
else if ( args [ i ] = = k_CommandLineModelOverrideExtensionFlag & & i < args . Length - 1 )
{
EditorApplication . isPlaying = false ;
#endif
}
m_OverrideExtensions . Add ( overrideExtension ) ;
overrideSettings . OverrideExtensions . Add ( overrideExtension ) ;
Int32 . TryParse ( args [ i + 1 ] , out maxEpisodes ) ;
Int32 . TryParse ( args [ i + 1 ] , out overrideSettings . MaxEpisodes ) ;
Int32 . TryParse ( args [ i + 1 ] , out timeoutSeconds ) ;
Int32 . TryParse ( args [ i + 1 ] , out overrideSettings . TimeoutSeconds ) ;
m_QuitOnLoadFailure = true ;
overrideSettings . QuitOnLoadFailure = true ;
if ( ! string . IsNullOrEmpty ( m_BehaviorNameOverrideDirectory ) )
return overrideSettings ;
}
/// <summary>
/// Get the asset path to use from the commandline arguments.
/// </summary>
/// <returns></returns>
void GetAssetPathFromCommandLine ( )
{
m_OverrideSettings = GetSettingsFromCommandLine ( debugCommandLineOverride ) ;
if ( ! string . IsNullOrEmpty ( m_OverrideSettings . BehaviorNameOverrideDirectory ) )
m_MaxEpisodes = maxEpisodes > 0 ? maxEpisodes : 1 ;
Debug . Log ( $"setting m_MaxEpisodes to {maxEpisodes}" ) ;
m_OverrideSettings . MaxEpisodes = m_OverrideSettings . MaxEpisodes > 0 ? m_OverrideSettings . MaxEpisodes : 1 ;
Debug . Log ( $"setting m_MaxEpisodes to {m_OverrideSettings.MaxEpisodes}" ) ;
if ( timeoutSeconds > 0 )
if ( m_OverrideSettings . TimeoutSeconds > 0 )
m_Deadline = DateTime . Now + TimeSpan . FromSeconds ( timeoutSeconds ) ;
Debug . Log ( $"setting deadline to {timeoutSeconds} from now." ) ;
m_Deadline = DateTime . Now + TimeSpan . FromSeconds ( m_OverrideSettings . TimeoutSeconds ) ;
Debug . Log ( $"Setting deadline to {m_OverrideSettings.TimeoutSeconds} seconds from now." ) ;
}
}
void FixedUpdate ( )
{
if ( m_MaxEpisodes > 0 )
if ( m_OverrideSettings . MaxEpisodes > 0 )
if ( TotalCompletedEpisodes > = m_MaxEpisodes & & TotalNumSteps > m_MaxEpisodes * m_Agent . MaxStep )
if ( TotalCompletedEpisodes > = m_OverrideSettings . MaxEpisodes & & TotalNumSteps > m_OverrideSettings . MaxEpisodes * m_Agent . MaxStep )
{
Debug . Log ( $"ModelOverride reached {TotalCompletedEpisodes} episodes and {TotalNumSteps} steps. Exiting." ) ;
Application . Quit ( 0 ) ;
{
Debug . Log (
$"Deadline exceeded. " +
$"{TotalCompletedEpisodes}/{m_MaxEpisodes} episodes and " +
$"{TotalNumSteps}/{m_MaxEpisodes * m_Agent.MaxStep} steps completed. Exiting." ) ;
$"{TotalCompletedEpisodes}/{m_OverrideSettings. MaxEpisodes} episodes and " +
$"{TotalNumSteps}/{m_OverrideSettings. MaxEpisodes * m_Agent.MaxStep} steps completed. Exiting." ) ;
Application . Quit ( 0 ) ;
#if UNITY_EDITOR
EditorApplication . isPlaying = false ;
return m_CachedModels [ behaviorName ] ;
}
if ( string . IsNullOrEmpty ( m_BehaviorNameOverrideDirectory ) )
if ( string . IsNullOrEmpty ( m_OverrideSettings . BehaviorNameOverrideDirectory ) )
{
Debug . Log ( $"No override directory set." ) ;
return null ;
var overrideExtensions = ( m_OverrideExtensions . Count > 0 )
? m_OverrideExtensions . ToArray ( )
var overrideExtensions = ( m_OverrideSettings . OverrideExtensions . Count > 0 )
? m_OverrideSettings . OverrideExtensions . ToArray ( )
: new [ ] { "nn" , "onnx" } ;
byte [ ] rawModel = null ;
{
var assetPath = Path . Combine ( m_BehaviorNameOverrideDirectory , $"{behaviorName}.{overrideExtension}" ) ;
var assetPath = Path . Combine ( m_OverrideSettings . BehaviorNameOverrideDirectory , $"{behaviorName}.{overrideExtension}" ) ;
try
{
rawModel = File . ReadAllBytes ( assetPath ) ;
if ( rawModel = = null )
{
Debug . Log ( $"Couldn't load model file(s) for {behaviorName} in {m_BehaviorNameOverrideDirectory} (full path: {Path.GetFullPath(m_BehaviorNameOverrideDirectory)}" ) ;
Debug . Log ( $"Couldn't load model file(s) for {behaviorName} in {m_OverrideSettings. BehaviorNameOverrideDirectory} (full path: {Path.GetFullPath(m_OverrideSettings. BehaviorNameOverrideDirectory)}" ) ;
// Cache the null so we don't repeatedly try to load a missing file
m_CachedModels [ behaviorName ] = null ;
return null ;
}
}
if ( ! overrideOk & & m_QuitOnLoadFailure )
if ( ! overrideOk & & m_OverrideSettings . QuitOnLoadFailure )
{
if ( ! string . IsNullOrEmpty ( overrideError ) )
{
#endif
}
}
public static void CheckSceneForModelOverrides ( )
{
var overrideSettings = GetSettingsFromCommandLine ( ) ;
if ( ! overrideSettings . HasOverrides ( ) )
{
// No overrides specified on the commmandline, so don't check the scene.
return ;
}
var overrideComponentsFound = SceneHasModelOverrideComponents ( ) ;
if ( overrideComponentsFound )
{
// Expected override components and found them.
return ;
}
Debug . LogError ( "Model overriding set on command line, but scene contains no ModelOverride components." ) ;
if ( overrideSettings . QuitOnLoadFailure )
{
Application . Quit ( 1 ) ;
#if UNITY_EDITOR
EditorApplication . isPlaying = false ;
#endif
}
}
static bool SceneHasModelOverrideComponents ( )
{
GameObject [ ] allObjects = UnityEngine . Object . FindObjectsOfType < GameObject > ( ) ;
Debug . Log ( $"Found {allObjects.Length} total GameObjects." ) ;
foreach ( var gameObj in allObjects )
{
Debug . Log ( $"Checking GameObject {gameObj.name}." ) ;
var modeloverride = gameObj . GetComponentsInChildren < ModelOverrider > ( ) ;
if ( modeloverride ! = null & & modeloverride . Length > 0 )
{
// Found at least 1 model override.
return true ;
}
}
Debug . Log ( $"No ModelOverriders found." ) ;
return false ;
}
}
}