浏览代码

WallJump - handle Agent starting before ModelOverrider (#4502)

* handle Agent starting before ModelOverrider
* allow multiple override types
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
f5bb26d2
共有 1 个文件被更改,包括 64 次插入61 次删除
  1. 125
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs

125
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs


public class ModelOverrider : MonoBehaviour
{
HashSet<string> k_SupportedExtensions = new HashSet<string> { "nn", "onnx" };
const string k_CommandLineModelOverrideFlag = "--mlagents-override-model";
const string k_CommandLineModelOverrideDirectoryFlag = "--mlagents-override-model-directory";
const string k_CommandLineModelOverrideExtensionFlag = "--mlagents-override-model-extension";
const string k_CommandLineQuitAfterEpisodesFlag = "--mlagents-quit-after-episodes";

Agent m_Agent;
// Assets paths to use, with the behavior name as the key.
Dictionary<string, string> m_BehaviorNameOverrides = new Dictionary<string, string>();
string m_OverrideExtension = "nn";
private List<string> m_OverrideExtensions = new List<string>();
// Cached loaded NNModels, with the behavior name as the key.
Dictionary<string, NNModel> m_CachedModels = new Dictionary<string, NNModel>();

public bool HasOverrides
{
get { return m_BehaviorNameOverrides.Count > 0 || !string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory); }
get
{
GetAssetPathFromCommandLine();
return !string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory);
}
}
public static string GetOverrideBehaviorName(string originalBehaviorName)

/// <returns></returns>
void GetAssetPathFromCommandLine()
{
m_BehaviorNameOverrides.Clear();
var maxEpisodes = 0;
string[] commandLineArgsOverride = null;
if (!string.IsNullOrEmpty(debugCommandLineOverride) && Application.isEditor)

var args = commandLineArgsOverride ?? Environment.GetCommandLineArgs();
for (var i = 0; i < args.Length; i++)
{
if (args[i] == k_CommandLineModelOverrideFlag && i < args.Length - 2)
{
var key = args[i + 1].Trim();
var value = args[i + 2].Trim();
m_BehaviorNameOverrides[key] = value;
}
else if (args[i] == k_CommandLineModelOverrideDirectoryFlag && i < args.Length - 1)
if (args[i] == k_CommandLineModelOverrideDirectoryFlag && i < args.Length - 1)
m_OverrideExtension = args[i + 1].Trim().ToLower();
var isKnownExtension = k_SupportedExtensions.Contains(m_OverrideExtension);
var overrideExtension = args[i + 1].Trim().ToLower();
var isKnownExtension = k_SupportedExtensions.Contains(overrideExtension);
Debug.LogError($"loading unsupported format: {m_OverrideExtension}");
Debug.LogError($"loading unsupported format: {overrideExtension}");
m_OverrideExtensions.Add(overrideExtension);
}
else if (args[i] == k_CommandLineQuitAfterEpisodesFlag && i < args.Length - 1)
{

}
}
if (HasOverrides)
if (!string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory))
{
// If overriding models, set maxEpisodes to 1 or the command line value
m_MaxEpisodes = maxEpisodes > 0 ? maxEpisodes : 1;

return m_CachedModels[behaviorName];
}
string assetPath = null;
if (m_BehaviorNameOverrides.ContainsKey(behaviorName))
if (string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory))
assetPath = m_BehaviorNameOverrides[behaviorName];
}
else if (!string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory))
{
assetPath = Path.Combine(m_BehaviorNameOverrideDirectory, $"{behaviorName}.{m_OverrideExtension}");
Debug.Log($"No override directory set.");
return null;
if (string.IsNullOrEmpty(assetPath))
{
Debug.Log($"No override for BehaviorName {behaviorName}, and no directory set.");
return null;
}
// Try the override extensions in order. If they weren't set, try .nn first, then .onnx.
var overrideExtensions = (m_OverrideExtensions.Count > 0)
? m_OverrideExtensions.ToArray()
: new[] { "nn", "onnx" };
try
bool isOnnx = false;
string assetName = null;
foreach (var overrideExtension in overrideExtensions)
rawModel = File.ReadAllBytes(assetPath);
var assetPath = Path.Combine(m_BehaviorNameOverrideDirectory, $"{behaviorName}.{overrideExtension}");
try
{
rawModel = File.ReadAllBytes(assetPath);
isOnnx = overrideExtension.Equals("onnx");
assetName = "Override - " + Path.GetFileName(assetPath);
break;
}
catch (IOException)
{
// Do nothing - try the next extension, or we'll exit if nothing loaded.
}
catch (IOException)
if (rawModel == null)
Debug.Log($"Couldn't load file {assetPath} at full path {Path.GetFullPath(assetPath)}", this);
Debug.Log($"Couldn't load model file(s) for {behaviorName} in {m_BehaviorNameOverrideDirectory} (full path: {Path.GetFullPath(m_BehaviorNameOverrideDirectory)}");
NNModel asset;
var isOnnx = m_OverrideExtension.Equals("onnx");
if (isOnnx)
{
var converter = new ONNXModelConverter(true);
var onnxModel = converter.Convert(rawModel);
var asset = isOnnx ? LoadOnnxModel(rawModel) : LoadBarracudaModel(rawModel);
asset.name = assetName;
m_CachedModels[behaviorName] = asset;
return asset;
}
NNModelData assetData = ScriptableObject.CreateInstance<NNModelData>();
using (var memoryStream = new MemoryStream())
using (var writer = new BinaryWriter(memoryStream))
{
ModelWriter.Save(writer, onnxModel);
assetData.Value = memoryStream.ToArray();
}
assetData.name = "Data";
assetData.hideFlags = HideFlags.HideInHierarchy;
NNModel LoadBarracudaModel(byte[] rawModel)
{
var asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = ScriptableObject.CreateInstance<NNModelData>();
asset.modelData.Value = rawModel;
return asset;
}
asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = assetData;
}
else
NNModel LoadOnnxModel(byte[] rawModel)
{
var converter = new ONNXModelConverter(true);
var onnxModel = converter.Convert(rawModel);
NNModelData assetData = ScriptableObject.CreateInstance<NNModelData>();
using (var memoryStream = new MemoryStream())
using (var writer = new BinaryWriter(memoryStream))
// Note - this approach doesn't work for onnx files. Need to replace with
// the equivalent of ONNXModelImporter.OnImportAsset()
asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = ScriptableObject.CreateInstance<NNModelData>();
asset.modelData.Value = rawModel;
ModelWriter.Save(writer, onnxModel);
assetData.Value = memoryStream.ToArray();
assetData.name = "Data";
assetData.hideFlags = HideFlags.HideInHierarchy;
asset.name = "Override - " + Path.GetFileName(assetPath);
m_CachedModels[behaviorName] = asset;
var asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = assetData;
/// <summary>
/// Load the NNModel file from the specified path, and give it to the attached agent.

正在加载...
取消
保存