浏览代码

Model override from commandline (#3265)

* WIP model override from commandline

* Agent lazy init, multiple overrides

* MLAgentsExamples namespace

* add model override to 3dball
/asymm-envs
GitHub 5 年前
当前提交
2db09cef
共有 19 个文件被更改,包括 222 次插入86 次删除
  1. 13
      Project/Assets/ML-Agents/Examples/3DBall/Prefabs/3DBall.prefab
  2. 4
      Project/Assets/ML-Agents/Examples/3DBall/Scenes/3DBall.unity
  3. 1
      Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
  4. 2
      Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorArea.cs
  5. 2
      Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs
  6. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/AdjustTrainingTimescale.cs
  7. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/Area.cs
  8. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/CameraFollow.cs
  9. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/FlyCamera.cs
  10. 3
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GroundContact.cs
  11. 3
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/JointDriveController.cs
  12. 99
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs
  13. 2
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/TargetContact.cs
  14. 1
      Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
  15. 25
      com.unity.ml-agents/Runtime/Agent.cs
  16. 4
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
  17. 28
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  18. 110
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs
  19. 3
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs.meta

13
Project/Assets/ML-Agents/Examples/3DBall/Prefabs/3DBall.prefab


- component: {fileID: 114368073295828880}
- component: {fileID: 114715123104194396}
- component: {fileID: 1306725529891448089}
- component: {fileID: 1758424554059689351}
m_Layer: 0
m_Name: Agent
m_TagString: Untagged

DecisionPeriod: 5
RepeatAction: 1
offsetStep: 0
--- !u!114 &1758424554059689351
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 1424713891854676}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 3a6da8f78a394c6ab027688eab81e04d, type: 3}
m_Name:
m_EditorClassIdentifier:
--- !u!1 &1533320402322554
GameObject:
m_ObjectHideFlags: 0

4
Project/Assets/ML-Agents/Examples/3DBall/Scenes/3DBall.unity


propertyPath: m_Name
value: 3DBall (1)
objectReference: {fileID: 0}
- target: {fileID: 1321468028730240, guid: cfa81c019162c4e3caf6e2999c6fdf48, type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 4679453577574622, guid: cfa81c019162c4e3caf6e2999c6fdf48, type: 3}
propertyPath: m_LocalPosition.x
value: 9

1
Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs


using UnityEngine;
using MLAgents;
using MLAgentsExamples;
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class CrawlerAgent : Agent

2
Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorArea.cs


using UnityEngine;
using MLAgents;
using MLAgentsExamples;
public class FoodCollectorArea : Area
{

2
Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidArea.cs


using UnityEngine;
using MLAgents;
using MLAgentsExamples;
public class PyramidArea : Area
{

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/AdjustTrainingTimescale.cs


using UnityEngine;
namespace MLAgents
namespace MLAgentsExamples
{
public class AdjustTrainingTimescale : MonoBehaviour
{

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/Area.cs


using UnityEngine;
namespace MLAgents
namespace MLAgentsExamples
{
public class Area : MonoBehaviour
{

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/CameraFollow.cs


using UnityEngine;
namespace MLAgents
namespace MLAgentsExamples
{
public class CameraFollow : MonoBehaviour
{

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/FlyCamera.cs


using UnityEngine;
namespace MLAgents
namespace MLAgentsExamples
{
public class FlyCamera : MonoBehaviour
{

3
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GroundContact.cs


using UnityEngine;
using MLAgents;
namespace MLAgents
namespace MLAgentsExamples
{
/// <summary>
/// This class contains logic for locomotion agents with joints which might make contact with the ground.

3
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/JointDriveController.cs


using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Serialization;
using MLAgents;
namespace MLAgents
namespace MLAgentsExamples
{
/// <summary>
/// Used to store relevant information for acting and learning for each body part in agent.

99
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ProjectSettingsOverrides.cs


using UnityEngine;
using MLAgents;
public class ProjectSettingsOverrides : MonoBehaviour
namespace MLAgentsExamples
// Original values
float m_OriginalMonitorVerticalOffset;
Vector3 m_OriginalGravity;
float m_OriginalFixedDeltaTime;
float m_OriginalMaximumDeltaTime;
int m_OriginalSolverIterations;
int m_OriginalSolverVelocityIterations;
public class ProjectSettingsOverrides : MonoBehaviour
{
// Original values
float m_OriginalMonitorVerticalOffset;
Vector3 m_OriginalGravity;
float m_OriginalFixedDeltaTime;
float m_OriginalMaximumDeltaTime;
int m_OriginalSolverIterations;
int m_OriginalSolverVelocityIterations;
[Tooltip("Increase or decrease the scene gravity. Use ~3x to make things less floaty")]
public float gravityMultiplier = 1.0f;
[Tooltip("Increase or decrease the scene gravity. Use ~3x to make things less floaty")]
public float gravityMultiplier = 1.0f;
[Header("Display Settings")]
public float monitorVerticalOffset;
[Header("Display Settings")]
public float monitorVerticalOffset;
[Header("Advanced physics settings")]
[Tooltip("The interval in seconds at which physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) are performed.")]
public float fixedDeltaTime = .02f;
[Tooltip("The maximum time a frame can take. Physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) will be performed only for this duration of time per frame.")]
public float maximumDeltaTime = 1.0f / 3.0f;
[Tooltip("Determines how accurately Rigidbody joints and collision contacts are resolved. (default 6). Must be positive.")]
public int solverIterations = 6;
[Tooltip("Affects how accurately the Rigidbody joints and collision contacts are resolved. (default 1). Must be positive.")]
public int solverVelocityIterations = 1;
[Header("Advanced physics settings")]
[Tooltip("The interval in seconds at which physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) are performed.")]
public float fixedDeltaTime = .02f;
[Tooltip("The maximum time a frame can take. Physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) will be performed only for this duration of time per frame.")]
public float maximumDeltaTime = 1.0f / 3.0f;
[Tooltip("Determines how accurately Rigidbody joints and collision contacts are resolved. (default 6). Must be positive.")]
public int solverIterations = 6;
[Tooltip("Affects how accurately the Rigidbody joints and collision contacts are resolved. (default 1). Must be positive.")]
public int solverVelocityIterations = 1;
public void Awake()
{
// Save the original values
m_OriginalMonitorVerticalOffset = Monitor.verticalOffset;
m_OriginalGravity = Physics.gravity;
m_OriginalFixedDeltaTime = Time.fixedDeltaTime;
m_OriginalMaximumDeltaTime = Time.maximumDeltaTime;
m_OriginalSolverIterations = Physics.defaultSolverIterations;
m_OriginalSolverVelocityIterations = Physics.defaultSolverVelocityIterations;
public void Awake()
{
// Save the original values
m_OriginalMonitorVerticalOffset = Monitor.verticalOffset;
m_OriginalGravity = Physics.gravity;
m_OriginalFixedDeltaTime = Time.fixedDeltaTime;
m_OriginalMaximumDeltaTime = Time.maximumDeltaTime;
m_OriginalSolverIterations = Physics.defaultSolverIterations;
m_OriginalSolverVelocityIterations = Physics.defaultSolverVelocityIterations;
// Override
Monitor.verticalOffset = monitorVerticalOffset;
Physics.gravity *= gravityMultiplier;
Time.fixedDeltaTime = fixedDeltaTime;
Time.maximumDeltaTime = maximumDeltaTime;
Physics.defaultSolverIterations = solverIterations;
Physics.defaultSolverVelocityIterations = solverVelocityIterations;
// Override
Monitor.verticalOffset = monitorVerticalOffset;
Physics.gravity *= gravityMultiplier;
Time.fixedDeltaTime = fixedDeltaTime;
Time.maximumDeltaTime = maximumDeltaTime;
Physics.defaultSolverIterations = solverIterations;
Physics.defaultSolverVelocityIterations = solverVelocityIterations;
Academy.Instance.FloatProperties.RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); });
}
Academy.Instance.FloatProperties.RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); });
}
public void OnDestroy()
{
Monitor.verticalOffset = m_OriginalMonitorVerticalOffset;
Physics.gravity = m_OriginalGravity;
Time.fixedDeltaTime = m_OriginalFixedDeltaTime;
Time.maximumDeltaTime = m_OriginalMaximumDeltaTime;
Physics.defaultSolverIterations = m_OriginalSolverIterations;
Physics.defaultSolverVelocityIterations = m_OriginalSolverVelocityIterations;
public void OnDestroy()
{
Monitor.verticalOffset = m_OriginalMonitorVerticalOffset;
Physics.gravity = m_OriginalGravity;
Time.fixedDeltaTime = m_OriginalFixedDeltaTime;
Time.maximumDeltaTime = m_OriginalMaximumDeltaTime;
Physics.defaultSolverIterations = m_OriginalSolverIterations;
Physics.defaultSolverVelocityIterations = m_OriginalSolverVelocityIterations;
}
}
}

2
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/TargetContact.cs


using UnityEngine;
namespace MLAgents
namespace MLAgentsExamples
{
/// <summary>
/// This class contains logic for locomotion agents with joints which might make contact with a target.

1
Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs


using UnityEngine;
using MLAgents;
using MLAgentsExamples;
public class WalkerAgent : Agent
{

25
com.unity.ml-agents/Runtime/Agent.cs


/// This Id will be changed every time the Agent resets.
int m_EpisodeId;
/// Whether or not the Agent has been initialized already
bool m_Initialized;
/// Keeps track of the actions that are masked at each step.
ActionMasker m_ActionMasker;

/// becomes enabled or active.
void OnEnable()
{
m_EpisodeId = EpisodeIdCounter.GetEpisodeId();
OnEnableHelper();
m_Recorder = GetComponent<DemonstrationRecorder>();
LazyInitialize();
void OnEnableHelper()
public void LazyInitialize()
if (m_Initialized)
{
return;
}
m_Initialized = true;
// Grab the "static" properties for the Agent.
m_EpisodeId = EpisodeIdCounter.GetEpisodeId();
m_PolicyFactory = GetComponent<BehaviorParameters>();
m_Recorder = GetComponent<DemonstrationRecorder>();
m_Info = new AgentInfo();
m_Action = new AgentAction();
sensors = new List<ISensor>();

Academy.Instance.AgentAct += AgentStep;
Academy.Instance.AgentForceReset += _AgentReset;
m_PolicyFactory = GetComponent<BehaviorParameters>();
}
/// Monobehavior function that is called when the attached GameObject

}
NotifyAgentDone();
m_Brain?.Dispose();
m_Initialized = false;
}
void NotifyAgentDone(bool maxStepReached = false)

4
com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs


var agents = new List<TestAgent> { agentA, agentB };
foreach (var agent in agents)
{
var agentEnableMethod = typeof(Agent).GetMethod("OnEnableHelper",
BindingFlags.Instance | BindingFlags.NonPublic);
agentEnableMethod?.Invoke(agent, new object[] {});
agent.LazyInitialize();
}
agentA.collectObservationsSensor.AddObservation(new Vector3(1, 2, 3));
agentB.collectObservationsSensor.AddObservation(new Vector3(4, 5, 6));

28
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


Assert.AreEqual(0, agent1.agentActionCalls);
Assert.AreEqual(0, agent2.agentActionCalls);
var agentEnableMethod = typeof(Agent).GetMethod("OnEnableHelper",
BindingFlags.Instance | BindingFlags.NonPublic);
agentEnableMethod?.Invoke(agent2, new object[] {});
agentEnableMethod?.Invoke(agent1, new object[] {});
agent2.LazyInitialize();
agent1.LazyInitialize();
// agent1 was not enabled when the academy started
// The agents have been initialized

var aca = Academy.Instance;
var agentEnableMethod = typeof(Agent).GetMethod(
"OnEnableHelper", BindingFlags.Instance | BindingFlags.NonPublic);
var decisionRequester = agent1.gameObject.AddComponent<DecisionRequester>();
decisionRequester.DecisionPeriod = 2;
decisionRequester.Awake();

agentEnableMethod?.Invoke(agent1, new object[] {});
agent1.LazyInitialize();
var numberAgent1Reset = 0;
var numberAgent2Initialization = 0;

//Agent 2 is only initialized at step 2
if (i == 2)
{
agentEnableMethod?.Invoke(agent2, new object[] {});
agent2.LazyInitialize();
numberAgent2Initialization += 1;
}

var aca = Academy.Instance;
var agentEnableMethod = typeof(Agent).GetMethod(
"OnEnableHelper", BindingFlags.Instance | BindingFlags.NonPublic);
agentEnableMethod?.Invoke(agent2, new object[] {});
agent2.LazyInitialize();
var numberAgent1Reset = 0;
var numberAgent2Reset = 0;

//Agent 1 is only initialized at step 2
if (i == 2)
{
agentEnableMethod?.Invoke(agent1, new object[] {});
agent1.LazyInitialize();
}
// Set agent 1 to done every 11 steps to test behavior
if (i % 11 == 5)

var agent2 = agentGo2.GetComponent<TestAgent>();
var aca = Academy.Instance;
var agentEnableMethod = typeof(Agent).GetMethod(
"OnEnableHelper", BindingFlags.Instance | BindingFlags.NonPublic);
var decisionRequester = agent1.gameObject.AddComponent<DecisionRequester>();
decisionRequester.DecisionPeriod = 2;
decisionRequester.Awake();

agentEnableMethod?.Invoke(agent2, new object[] {});
agentEnableMethod?.Invoke(agent1, new object[] {});
agent2.LazyInitialize();
agent1.LazyInitialize();
var j = 0;
for (var i = 0; i < 500; i++)

110
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs


using System;
using System.Collections.Generic;
using UnityEngine;
using Barracuda;
using System.IO;
using MLAgents;
namespace MLAgentsExamples
{
    /// <summary>
    /// Utility class to allow the NNModel file for an agent to be overridden during inference.
    /// This is useful to validate the file after training is done.
    /// The behavior name to override and file path are specified on the commandline, e.g.
    /// player.exe --mlagents-override-model behavior1 /path/to/model1.nn --mlagents-override-model behavior2 /path/to/model2.nn
    /// Note this will only work with example scenes that have 1:1 Agent:Behaviors. More complicated scenes like WallJump
    /// probably won't override correctly.
    /// </summary>
    public class ModelOverrider : MonoBehaviour
    {
        const string k_CommandLineFlag = "--mlagents-override-model";

        // Asset paths to use, with the behavior name as the key.
        Dictionary<string, string> m_BehaviorNameOverrides = new Dictionary<string, string>();

        // Cached loaded NNModels, with the behavior name as the key.
        // A null value records a failed load so the file is not read again.
        Dictionary<string, NNModel> m_CachedModels = new Dictionary<string, NNModel>();

        /// <summary>
        /// Rebuild the behavior-name-to-path table from the commandline arguments.
        /// Every occurrence of the override flag must be followed by two arguments:
        /// the behavior name and the model file path. A later occurrence for the
        /// same behavior name replaces an earlier one.
        /// </summary>
        void GetAssetPathFromCommandLine()
        {
            m_BehaviorNameOverrides.Clear();

            var args = Environment.GetCommandLineArgs();
            // Stop two short of the end: the flag needs both a name and a path after it.
            for (var i = 0; i < args.Length - 2; i++)
            {
                if (args[i] == k_CommandLineFlag)
                {
                    var key = args[i + 1].Trim();
                    var value = args[i + 2].Trim();
                    m_BehaviorNameOverrides[key] = value;
                }
            }
        }

        /// <summary>
        /// Read the commandline and, only if at least one override was requested,
        /// apply the override to the agent on this GameObject.
        /// </summary>
        void OnEnable()
        {
            GetAssetPathFromCommandLine();
            if (m_BehaviorNameOverrides.Count > 0)
            {
                OverrideModel();
            }
        }

        /// <summary>
        /// Return the override model for a behavior name, loading it from disk on first use.
        /// </summary>
        /// <param name="behaviorName">Behavior name to look up (without any "?..." suffix).</param>
        /// <returns>
        /// The loaded NNModel, or null when no override was requested for this
        /// behavior or the file could not be read.
        /// </returns>
        NNModel GetModelForBehaviorName(string behaviorName)
        {
            NNModel cachedModel;
            if (m_CachedModels.TryGetValue(behaviorName, out cachedModel))
            {
                return cachedModel;
            }

            string assetPath;
            if (!m_BehaviorNameOverrides.TryGetValue(behaviorName, out assetPath))
            {
                Debug.Log($"No override for behaviorName {behaviorName}");
                return null;
            }

            byte[] model;
            try
            {
                model = File.ReadAllBytes(assetPath);
            }
            catch (Exception e) when (e is IOException || e is UnauthorizedAccessException)
            {
                // Missing file, bad directory, or no permission: treat all of them
                // as "no model available" rather than failing OnEnable.
                Debug.Log($"Couldn't load file {assetPath}", this);
                // Cache the null so we don't repeatedly try to load a missing file
                m_CachedModels[behaviorName] = null;
                return null;
            }

            var asset = ScriptableObject.CreateInstance<NNModel>();
            asset.Value = model;
            asset.name = "Override - " + Path.GetFileName(assetPath);
            m_CachedModels[behaviorName] = asset;
            return asset;
        }

        /// <summary>
        /// Load the NNModel file from the specified path, and give it to the attached agent.
        /// </summary>
        void OverrideModel()
        {
            var agent = GetComponent<Agent>();
            if (agent == null)
            {
                // Nothing to override on this GameObject; avoid a NullReferenceException.
                Debug.Log("ModelOverrider requires an Agent on the same GameObject; skipping override.", this);
                return;
            }

            // Make sure the agent has run its one-time setup before a model is handed to it.
            agent.LazyInitialize();

            var bp = agent.GetComponent<BehaviorParameters>();
            // Keep only the part of the name before the first '?'.
            var behaviorNameAndTeamId = bp.behaviorName;
            var behaviorName = behaviorNameAndTeamId.Split('?')[0];

            var nnModel = GetModelForBehaviorName(behaviorName);
            Debug.Log($"Overriding behavior {behaviorName} for agent with model {nnModel?.name}");
            // This might give a null model; that's better because we'll fall back to the Heuristic
            agent.GiveModel($"Override_{behaviorName}", nnModel, InferenceDevice.CPU);
        }
    }
}

3
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs.meta


fileFormatVersion: 2
guid: 3a6da8f78a394c6ab027688eab81e04d
timeCreated: 1579651041
正在加载...
取消
保存