浏览代码
Model override from commandline (#3265)
Model override from commandline (#3265)
* WIP model override from commandline * Agent lazy init, multiple overrides * MLAgentsExamples namespace * add model override to 3dball/asymm-envs
GitHub
5 年前
当前提交
2db09cef
共有 19 个文件被更改,包括 222 次插入 和 86 次删除
-
13Project/Assets/ML-Agents/Examples/3DBall/Prefabs/3DBall.prefab
-
4Project/Assets/ML-Agents/Examples/3DBall/Scenes/3DBall.unity
-
1Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
-
2Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorArea.cs
-
2Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs
-
2Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/AdjustTrainingTimescale.cs
-
2Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/Area.cs
-
2Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/CameraFollow.cs
-
2Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/FlyCamera.cs
-
3Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GroundContact.cs
-
3Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/JointDriveController.cs
-
99Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs
-
2Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/TargetContact.cs
-
1Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
-
25com.unity.ml-agents/Runtime/Agent.cs
-
4com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
-
28com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
-
110Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs
-
3Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs.meta
|
|||
using UnityEngine; |
|||
using MLAgents; |
|||
|
|||
public class ProjectSettingsOverrides : MonoBehaviour |
|||
namespace MLAgentsExamples |
|||
// Original values
|
|||
float m_OriginalMonitorVerticalOffset; |
|||
Vector3 m_OriginalGravity; |
|||
float m_OriginalFixedDeltaTime; |
|||
float m_OriginalMaximumDeltaTime; |
|||
int m_OriginalSolverIterations; |
|||
int m_OriginalSolverVelocityIterations; |
|||
public class ProjectSettingsOverrides : MonoBehaviour |
|||
{ |
|||
// Original values
|
|||
float m_OriginalMonitorVerticalOffset; |
|||
Vector3 m_OriginalGravity; |
|||
float m_OriginalFixedDeltaTime; |
|||
float m_OriginalMaximumDeltaTime; |
|||
int m_OriginalSolverIterations; |
|||
int m_OriginalSolverVelocityIterations; |
|||
[Tooltip("Increase or decrease the scene gravity. Use ~3x to make things less floaty")] |
|||
public float gravityMultiplier = 1.0f; |
|||
[Tooltip("Increase or decrease the scene gravity. Use ~3x to make things less floaty")] |
|||
public float gravityMultiplier = 1.0f; |
|||
[Header("Display Settings")] |
|||
public float monitorVerticalOffset; |
|||
[Header("Display Settings")] |
|||
public float monitorVerticalOffset; |
|||
[Header("Advanced physics settings")] |
|||
[Tooltip("The interval in seconds at which physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) are performed.")] |
|||
public float fixedDeltaTime = .02f; |
|||
[Tooltip("The maximum time a frame can take. Physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) will be performed only for this duration of time per frame.")] |
|||
public float maximumDeltaTime = 1.0f / 3.0f; |
|||
[Tooltip("Determines how accurately Rigidbody joints and collision contacts are resolved. (default 6). Must be positive.")] |
|||
public int solverIterations = 6; |
|||
[Tooltip("Affects how accurately the Rigidbody joints and collision contacts are resolved. (default 1). Must be positive.")] |
|||
public int solverVelocityIterations = 1; |
|||
[Header("Advanced physics settings")] |
|||
[Tooltip("The interval in seconds at which physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) are performed.")] |
|||
public float fixedDeltaTime = .02f; |
|||
[Tooltip("The maximum time a frame can take. Physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) will be performed only for this duration of time per frame.")] |
|||
public float maximumDeltaTime = 1.0f / 3.0f; |
|||
[Tooltip("Determines how accurately Rigidbody joints and collision contacts are resolved. (default 6). Must be positive.")] |
|||
public int solverIterations = 6; |
|||
[Tooltip("Affects how accurately the Rigidbody joints and collision contacts are resolved. (default 1). Must be positive.")] |
|||
public int solverVelocityIterations = 1; |
|||
public void Awake() |
|||
{ |
|||
// Save the original values
|
|||
m_OriginalMonitorVerticalOffset = Monitor.verticalOffset; |
|||
m_OriginalGravity = Physics.gravity; |
|||
m_OriginalFixedDeltaTime = Time.fixedDeltaTime; |
|||
m_OriginalMaximumDeltaTime = Time.maximumDeltaTime; |
|||
m_OriginalSolverIterations = Physics.defaultSolverIterations; |
|||
m_OriginalSolverVelocityIterations = Physics.defaultSolverVelocityIterations; |
|||
public void Awake() |
|||
{ |
|||
// Save the original values
|
|||
m_OriginalMonitorVerticalOffset = Monitor.verticalOffset; |
|||
m_OriginalGravity = Physics.gravity; |
|||
m_OriginalFixedDeltaTime = Time.fixedDeltaTime; |
|||
m_OriginalMaximumDeltaTime = Time.maximumDeltaTime; |
|||
m_OriginalSolverIterations = Physics.defaultSolverIterations; |
|||
m_OriginalSolverVelocityIterations = Physics.defaultSolverVelocityIterations; |
|||
// Override
|
|||
Monitor.verticalOffset = monitorVerticalOffset; |
|||
Physics.gravity *= gravityMultiplier; |
|||
Time.fixedDeltaTime = fixedDeltaTime; |
|||
Time.maximumDeltaTime = maximumDeltaTime; |
|||
Physics.defaultSolverIterations = solverIterations; |
|||
Physics.defaultSolverVelocityIterations = solverVelocityIterations; |
|||
// Override
|
|||
Monitor.verticalOffset = monitorVerticalOffset; |
|||
Physics.gravity *= gravityMultiplier; |
|||
Time.fixedDeltaTime = fixedDeltaTime; |
|||
Time.maximumDeltaTime = maximumDeltaTime; |
|||
Physics.defaultSolverIterations = solverIterations; |
|||
Physics.defaultSolverVelocityIterations = solverVelocityIterations; |
|||
Academy.Instance.FloatProperties.RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); }); |
|||
} |
|||
Academy.Instance.FloatProperties.RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); }); |
|||
} |
|||
public void OnDestroy() |
|||
{ |
|||
Monitor.verticalOffset = m_OriginalMonitorVerticalOffset; |
|||
Physics.gravity = m_OriginalGravity; |
|||
Time.fixedDeltaTime = m_OriginalFixedDeltaTime; |
|||
Time.maximumDeltaTime = m_OriginalMaximumDeltaTime; |
|||
Physics.defaultSolverIterations = m_OriginalSolverIterations; |
|||
Physics.defaultSolverVelocityIterations = m_OriginalSolverVelocityIterations; |
|||
public void OnDestroy() |
|||
{ |
|||
Monitor.verticalOffset = m_OriginalMonitorVerticalOffset; |
|||
Physics.gravity = m_OriginalGravity; |
|||
Time.fixedDeltaTime = m_OriginalFixedDeltaTime; |
|||
Time.maximumDeltaTime = m_OriginalMaximumDeltaTime; |
|||
Physics.defaultSolverIterations = m_OriginalSolverIterations; |
|||
Physics.defaultSolverVelocityIterations = m_OriginalSolverVelocityIterations; |
|||
} |
|||
} |
|||
} |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
using Barracuda; |
|||
using System.IO; |
|||
using MLAgents; |
|||
|
|||
namespace MLAgentsExamples |
|||
{ |
|||
/// <summary>
|
|||
/// Utility class to allow the NNModel file for an agent to be overridden during inference.
|
|||
/// This is useful to validate the file after training is done.
|
|||
/// The behavior name to override and file path are specified on the commandline, e.g.
|
|||
/// player.exe --mlagents-override-model behavior1 /path/to/model1.nn --mlagents-override-model behavior2 /path/to/model2.nn
|
|||
/// Note this will only work with example scenes that have 1:1 Agent:Behaviors. More complicated scenes like WallJump
|
|||
/// probably won't override correctly.
|
|||
/// </summary>
|
|||
public class ModelOverrider : MonoBehaviour
{
    const string k_CommandLineFlag = "--mlagents-override-model";

    // Model file paths to load, with the behavior name as the key.
    Dictionary<string, string> m_BehaviorNameOverrides = new Dictionary<string, string>();

    // Cached loaded NNModels, with the behavior name as the key.
    // A null value records a failed load so the same file is not read again.
    Dictionary<string, NNModel> m_CachedModels = new Dictionary<string, NNModel>();

    /// <summary>
    /// Rebuild the behavior-name-to-file-path table from the commandline arguments.
    /// Every occurrence of the override flag that is followed by at least two more
    /// arguments contributes one entry (behavior name, then file path). If the same
    /// behavior name appears more than once, the last occurrence wins.
    /// </summary>
    void GetAssetPathFromCommandLine()
    {
        m_BehaviorNameOverrides.Clear();

        var args = Environment.GetCommandLineArgs();
        // Stop two short of the end: the flag is only meaningful with two arguments after it.
        for (var i = 0; i < args.Length - 2; i++)
        {
            if (args[i] == k_CommandLineFlag)
            {
                var key = args[i + 1].Trim();
                var value = args[i + 2].Trim();
                m_BehaviorNameOverrides[key] = value;
            }
        }
    }

    /// <summary>
    /// Read the overrides from the commandline and, if any were given,
    /// apply the override to the Agent on this GameObject.
    /// </summary>
    void OnEnable()
    {
        GetAssetPathFromCommandLine();
        if (m_BehaviorNameOverrides.Count > 0)
        {
            OverrideModel();
        }
    }

    /// <summary>
    /// Return the override model for a behavior name, loading it from disk on first use.
    /// </summary>
    /// <param name="behaviorName">Behavior name to look up in the commandline overrides.</param>
    /// <returns>
    /// The loaded model, or null when no override was given for the behavior
    /// or the file could not be read.
    /// </returns>
    NNModel GetModelForBehaviorName(string behaviorName)
    {
        // The cached value may itself be null (an earlier failed load); return it as-is.
        if (m_CachedModels.TryGetValue(behaviorName, out var cachedModel))
        {
            return cachedModel;
        }

        if (!m_BehaviorNameOverrides.TryGetValue(behaviorName, out var assetPath))
        {
            Debug.Log($"No override for behaviorName {behaviorName}");
            return null;
        }

        byte[] model = null;
        try
        {
            model = File.ReadAllBytes(assetPath);
        }
        // UnauthorizedAccessException covers a path that is a directory or is not readable;
        // treat it like a missing file instead of letting it abort OnEnable.
        catch (Exception e) when (e is IOException || e is UnauthorizedAccessException)
        {
            Debug.Log($"Couldn't load file {assetPath}", this);
            // Cache the null so we don't repeatedly try to load a missing file
            m_CachedModels[behaviorName] = null;
            return null;
        }

        var asset = ScriptableObject.CreateInstance<NNModel>();
        asset.Value = model;
        asset.name = "Override - " + Path.GetFileName(assetPath);
        m_CachedModels[behaviorName] = asset;
        return asset;
    }

    /// <summary>
    /// Load the NNModel file from the specified path, and give it to the attached agent.
    /// Does nothing (other than logging) when this GameObject has no Agent or
    /// the Agent has no BehaviorParameters.
    /// </summary>
    void OverrideModel()
    {
        var agent = GetComponent<Agent>();
        if (agent == null)
        {
            Debug.LogWarning("ModelOverrider found no Agent on this GameObject; nothing to override.", this);
            return;
        }

        agent.LazyInitialize();
        var bp = agent.GetComponent<BehaviorParameters>();
        if (bp == null)
        {
            Debug.LogWarning("ModelOverrider found no BehaviorParameters on the Agent; nothing to override.", this);
            return;
        }

        // The name may carry a "?..." suffix (presumably the team id - TODO confirm);
        // only the part before '?' is used as the lookup key.
        var behaviorNameAndTeamId = bp.behaviorName;
        var behaviorName = behaviorNameAndTeamId.Split('?')[0];
        var nnModel = GetModelForBehaviorName(behaviorName);
        Debug.Log($"Overriding behavior {behaviorName} for agent with model {nnModel?.name}");
        // This might give a null model; that's better because we'll fall back to the Heuristic
        agent.GiveModel($"Override_{behaviorName}", nnModel, InferenceDevice.CPU);
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 3a6da8f78a394c6ab027688eab81e04d |
|||
timeCreated: 1579651041 |
撰写
预览
正在加载...
取消
保存
Reference in new issue