Ruo-Ping Dong
4 年前
当前提交
80fe3d96
共有 10 个文件被更改,包括 485 次插入 和 1 次删除
-
14com.unity.ml-agents/Runtime/Academy.cs
-
9com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
-
200com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs
-
11com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs.meta
-
95com.unity.ml-agents/Runtime/Policies/TrainingPolicy.cs
-
11com.unity.ml-agents/Runtime/Policies/TrainingPolicy.cs.meta
-
62com.unity.ml-agents/Runtime/ReplayBuffer.cs
-
11com.unity.ml-agents/Runtime/ReplayBuffer.cs.meta
-
62com.unity.ml-agents/Runtime/Trainer.cs
-
11com.unity.ml-agents/Runtime/Trainer.cs.meta
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using Unity.Barracuda; |
|||
using UnityEngine.Profiling; |
|||
using Unity.MLAgents.Actuators; |
|||
using Unity.MLAgents.Inference; |
|||
using Unity.MLAgents.Policies; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents |
|||
{ |
|||
    /// <summary>
    /// Runs a Barracuda model to produce actions for a batch of agents during
    /// in-editor (C#) training. Observations are queued via
    /// <see cref="PutObservations"/> and consumed in one inference pass by
    /// <see cref="DecideBatch"/>; results are retrieved per agent with
    /// <see cref="GetAction"/>.
    /// </summary>
    internal class TrainingModelRunner
    {
        // Observations queued since the last DecideBatch call (one entry per decision request).
        List<AgentInfoSensorsPair> m_Infos = new List<AgentInfoSensorsPair>();
        // Latest action computed for each episodeId; entries are removed when the agent is done.
        Dictionary<int, ActionBuffers> m_LastActionsReceived = new Dictionary<int, ActionBuffers>();
        // episodeIds in the order their decisions were requested, so outputs can be
        // matched back to agents after batched inference.
        List<int> m_OrderedAgentsRequestingDecisions = new List<int>();

        ITensorAllocator m_TensorAllocator;
        TensorGenerator m_TensorGenerator;
        TensorApplier m_TensorApplier;

        // NOTE(review): m_Model, m_TargetModel, m_ModelName and m_InferenceDevice are
        // never assigned anywhere in this class, so Model/InferenceDevice/HasModel
        // all operate on default values — presumably placeholders for the WIP
        // C#-training workflow; confirm before relying on them.
        NNModel m_Model;
        NNModel m_TargetModel;
        string m_ModelName;
        InferenceDevice m_InferenceDevice;
        IWorker m_Engine;
        bool m_Verbose = false;
        string[] m_OutputNames;
        IReadOnlyList<TensorProxy> m_InferenceInputs;
        List<TensorProxy> m_InferenceOutputs;
        // Scratch map handed to the Barracuda worker; rebuilt on every batch.
        Dictionary<string, Tensor> m_InputsByName;
        Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();

        SensorShapeValidator m_SensorShapeValidator = new SensorShapeValidator();

        // Lazily set on the first DecideBatch so observation tensors are sized
        // from a real agent's sensors.
        bool m_ObservationsInitialized;

        /// <summary>
        /// Initializes the Brain with the Model that it will use when selecting actions for
        /// the agents
        /// </summary>
        /// <param name="actionSpec"> Description of the actions for the Agent.</param>
        /// <param name="seed"> The seed that will be used to initialize the RandomNormal
        /// and Multinomial objects used when running inference.</param>
        public TrainingModelRunner(
            ActionSpec actionSpec,
            int seed = 0)
        {
            Model barracudaModel;
            m_TensorAllocator = new TensorCachingAllocator();

            // barracudaModel = Barracuda.SomeModelBuilder.CreateModel();
            // NOTE(review): loading a freshly constructed (empty) NNModel looks like a
            // WIP placeholder until a real model builder exists — confirm; this will
            // not produce a usable model as written.
            barracudaModel = ModelLoader.Load(new NNModel());
            WorkerFactory.Type executionDevice = WorkerFactory.Type.CSharpBurst;
            m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose);

            m_InferenceInputs = barracudaModel.GetInputTensors();
            m_OutputNames = barracudaModel.GetOutputNames();
            m_TensorGenerator = new TensorGenerator(
                seed, m_TensorAllocator, m_Memories, barracudaModel);
            m_TensorApplier = new TensorApplier(
                actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel);
            m_InputsByName = new Dictionary<string, Tensor>();
            m_InferenceOutputs = new List<TensorProxy>();
        }

        /// <summary>Inference device this runner was configured with (never assigned here; see field note).</summary>
        public InferenceDevice InferenceDevice
        {
            get { return m_InferenceDevice; }
        }

        /// <summary>The NNModel associated with this runner (never assigned here; see field note).</summary>
        public NNModel Model
        {
            get { return m_Model; }
        }

        // Copies the generated TensorProxy data into the name->Tensor map expected
        // by IWorker.Execute.
        void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)
        {
            m_InputsByName.Clear();
            for (var i = 0; i < infInputs.Count; i++)
            {
                var inp = infInputs[i];
                m_InputsByName[inp.name] = inp.data;
            }
        }

        /// <summary>Releases the Barracuda worker and cached tensor allocations.</summary>
        public void Dispose()
        {
            if (m_Engine != null)
                m_Engine.Dispose();
            m_TensorAllocator?.Reset(false);
        }

        // Wraps each named worker output in a TensorProxy for the applier.
        void FetchBarracudaOutputs(string[] names)
        {
            m_InferenceOutputs.Clear();
            foreach (var n in names)
            {
                var output = m_Engine.PeekOutput(n);
                m_InferenceOutputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
            }
        }

        /// <summary>
        /// Queues an agent's observations for the next batched inference pass.
        /// </summary>
        /// <param name="info">Agent metadata (episodeId, done flag, reward, ...).</param>
        /// <param name="sensors">The agent's sensors; shapes are validated in DEBUG builds.</param>
        public void PutObservations(AgentInfo info, List<ISensor> sensors)
        {
#if DEBUG
            m_SensorShapeValidator.ValidateSensors(sensors);
#endif
            m_Infos.Add(new AgentInfoSensorsPair
            {
                agentInfo = info,
                sensors = sensors
            });

            // We add the episodeId to this list to maintain the order in which the decisions were requested
            m_OrderedAgentsRequestingDecisions.Add(info.episodeId);

            if (!m_LastActionsReceived.ContainsKey(info.episodeId))
            {
                m_LastActionsReceived[info.episodeId] = ActionBuffers.Empty;
            }
            if (info.done)
            {
                // If the agent is done, we remove the key from the last action dictionary since no action
                // should be taken.
                m_LastActionsReceived.Remove(info.episodeId);
            }
        }

        /// <summary>
        /// Runs one inference pass over all queued observations and stores the
        /// resulting actions in the per-agent action map. No-op when nothing is queued.
        /// </summary>
        public void DecideBatch()
        {
            var currentBatchSize = m_Infos.Count;
            if (currentBatchSize == 0)
            {
                return;
            }
            if (!m_ObservationsInitialized)
            {
                // Just grab the first agent in the collection (any will suffice, really).
                // We check for an empty Collection above, so this will always return successfully.
                var firstInfo = m_Infos[0];
                m_TensorGenerator.InitializeObservations(firstInfo.sensors, m_TensorAllocator);
                m_ObservationsInitialized = true;
            }

            Profiler.BeginSample("ModelRunner.DecideAction");
            Profiler.BeginSample(m_ModelName);

            Profiler.BeginSample($"GenerateTensors");
            // Prepare the input tensors to be feed into the engine
            m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Infos);
            Profiler.EndSample();

            Profiler.BeginSample($"PrepareBarracudaInputs");
            PrepareBarracudaInputs(m_InferenceInputs);
            Profiler.EndSample();

            // Execute the Model
            Profiler.BeginSample($"ExecuteGraph");
            m_Engine.Execute(m_InputsByName);
            Profiler.EndSample();

            Profiler.BeginSample($"FetchBarracudaOutputs");
            FetchBarracudaOutputs(m_OutputNames);
            Profiler.EndSample();

            Profiler.BeginSample($"ApplyTensors");
            // Update the outputs
            m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);
            Profiler.EndSample();

            Profiler.EndSample(); // end name
            Profiler.EndSample(); // end ModelRunner.DecideAction

            m_Infos.Clear();

            m_OrderedAgentsRequestingDecisions.Clear();
        }

        /// <summary>
        /// True when this runner was built for the given model/device pair.
        /// NOTE(review): m_Model and m_InferenceDevice are never assigned in this
        /// class, so this currently only matches (null, default) — confirm intent.
        /// </summary>
        public bool HasModel(NNModel other, InferenceDevice otherInferenceDevice)
        {
            return m_Model == other && m_InferenceDevice == otherInferenceDevice;
        }

        /// <summary>
        /// Returns the last action computed for the given agent, or
        /// <see cref="ActionBuffers.Empty"/> when none is recorded (e.g. the agent is done).
        /// </summary>
        public ActionBuffers GetAction(int agentId)
        {
            if (m_LastActionsReceived.ContainsKey(agentId))
            {
                return m_LastActionsReceived[agentId];
            }
            return ActionBuffers.Empty;
        }
    }
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 03ace8815cd804ee994a5068f618b845 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
// Policy for C# training
|
|||
|
|||
using Unity.Barracuda; |
|||
using System.Collections.Generic; |
|||
using Unity.MLAgents.Actuators; |
|||
using Unity.MLAgents.Inference; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Policies |
|||
{ |
|||
/// <summary>
|
|||
/// The Barracuda Policy uses a Barracuda Model to make decisions at
|
|||
/// every step. It uses a ModelRunner that is shared across all
|
|||
/// Barracuda Policies that use the same model and inference devices.
|
|||
/// </summary>
|
|||
internal class TrainingPolicy : IPolicy |
|||
{ |
|||
protected TrainingModelRunner m_ModelRunner; |
|||
ActionBuffers m_LastActionBuffer; |
|||
|
|||
int m_AgentId; |
|||
|
|||
/// <summary>
|
|||
/// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
|
|||
/// </summary>
|
|||
List<int[]> m_SensorShapes; |
|||
ActionSpec m_ActionSpec; |
|||
|
|||
private string m_BehaviorName; |
|||
|
|||
/// <summary>
|
|||
/// List of actuators, only used for analytics
|
|||
/// </summary>
|
|||
private IList<IActuator> m_Actuators; |
|||
|
|||
/// <summary>
|
|||
/// Whether or not we've tried to send analytics for this model. We only ever try to send once per policy,
|
|||
/// and do additional deduplication in the analytics code.
|
|||
/// </summary>
|
|||
private bool m_AnalyticsSent; |
|||
|
|||
private AgentInfo m_LastInfo; |
|||
|
|||
/// <inheritdoc />
|
|||
public TrainingPolicy( |
|||
ActionSpec actionSpec, |
|||
IList<IActuator> actuators, |
|||
string behaviorName |
|||
) |
|||
{ |
|||
m_ModelRunner = Academy.Instance.GetOrCreateTrainingModelRunner(behaviorName, actionSpec); |
|||
m_BehaviorName = behaviorName; |
|||
m_ActionSpec = actionSpec; |
|||
m_Actuators = actuators; |
|||
} |
|||
|
|||
/// <inheritdoc />
|
|||
public void RequestDecision(AgentInfo info, List<ISensor> sensors) |
|||
{ |
|||
if (!m_AnalyticsSent) |
|||
{ |
|||
m_AnalyticsSent = true; |
|||
Analytics.InferenceAnalytics.InferenceModelSet( |
|||
m_ModelRunner.Model, |
|||
m_BehaviorName, |
|||
m_ModelRunner.InferenceDevice, |
|||
sensors, |
|||
m_ActionSpec, |
|||
m_Actuators |
|||
); |
|||
} |
|||
m_AgentId = info.episodeId; |
|||
m_ModelRunner?.PutObservations(info, sensors); |
|||
} |
|||
|
|||
/// <inheritdoc />
|
|||
public ref readonly ActionBuffers DecideAction() |
|||
{ |
|||
if (m_ModelRunner == null) |
|||
{ |
|||
m_LastActionBuffer = ActionBuffers.Empty; |
|||
} |
|||
else |
|||
{ |
|||
m_ModelRunner?.DecideBatch(); |
|||
m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId); |
|||
} |
|||
return ref m_LastActionBuffer; |
|||
} |
|||
|
|||
public void Dispose() |
|||
{ |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 30a25b3276c294e5eb07b57fc1af4bdb |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
// Buffer for C# training
|
|||
|
|||
using System; |
|||
using System.Linq; |
|||
using Unity.Barracuda; |
|||
using System.Collections.Generic; |
|||
using Unity.MLAgents.Actuators; |
|||
|
|||
namespace Unity.MLAgents |
|||
{ |
|||
    // One step of replay experience: the classic (s, a, r, s') tuple.
    internal struct Transition
    {
        public List<Tensor> state;      // observations at the step the action was taken; presumably one Tensor per sensor — confirm with caller
        public ActionBuffers action;    // action taken at that step
        public float reward;            // reward received for the action
        public List<Tensor> nextState;  // observations at the following step
    }
|||
|
|||
internal class ReplayBuffer |
|||
{ |
|||
List<Transition> m_Buffer; |
|||
int currentIndex; |
|||
int m_MaxSize; |
|||
|
|||
public ReplayBuffer(int maxSize) |
|||
{ |
|||
m_Buffer = new List<Transition>(); |
|||
m_Buffer.Capacity = maxSize; |
|||
m_MaxSize = maxSize; |
|||
} |
|||
|
|||
public void Push(AgentInfo info, List<Tensor> state, List<Tensor> nextState) |
|||
{ |
|||
m_Buffer[currentIndex] = new Transition() {state=state, action=info.storedActions, reward=info.reward, nextState=nextState}; |
|||
currentIndex += 1; |
|||
currentIndex = currentIndex % m_MaxSize; |
|||
} |
|||
|
|||
public Transition[] Sample(int batchSize) |
|||
{ |
|||
var indexList = SampleIndex(batchSize); |
|||
var samples = new Transition[batchSize]; |
|||
for (var i = 0; i < batchSize; i++) |
|||
{ |
|||
samples[i] = m_Buffer[indexList[i]]; |
|||
} |
|||
return samples; |
|||
} |
|||
|
|||
private List<int> SampleIndex(int batchSize) |
|||
{ |
|||
Random random = new Random(); |
|||
HashSet<int> index = new HashSet<int>(); |
|||
|
|||
while (index.Count < batchSize) |
|||
{ |
|||
index.Add(random.Next(m_Buffer.Count)); |
|||
} |
|||
return index.ToList(); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: be3c5834a200742ed983cd073dd69f9a |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
// Trainer for C# training. One trainer per behavior.
|
|||
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using Unity.MLAgents; |
|||
using Unity.MLAgents.Actuators; |
|||
using Unity.MLAgents.Sensors; |
|||
using Unity.MLAgents.Analytics; |
|||
using Unity.MLAgents.Inference; |
|||
|
|||
|
|||
namespace Unity.MLAgents |
|||
{ |
|||
internal class Trainer: IDisposable |
|||
{ |
|||
ReplayBuffer m_Buffer; |
|||
TrainingModelRunner m_ModelRunner; |
|||
string m_behaviorName; |
|||
int m_BufferSize; |
|||
int batchSize; |
|||
float GAMMA; |
|||
|
|||
public Trainer(string behaviorName, ActionSpec actionSpec, int seed=0) |
|||
{ |
|||
m_behaviorName = behaviorName; |
|||
m_Buffer = new ReplayBuffer(m_BufferSize); |
|||
m_ModelRunner = new TrainingModelRunner(actionSpec, seed); |
|||
Academy.Instance.TrainerUpdate += Update; |
|||
} |
|||
|
|||
public string BehaviorName |
|||
{ |
|||
get => m_behaviorName; |
|||
} |
|||
|
|||
public TrainingModelRunner TrainerModelRunner |
|||
{ |
|||
get => m_ModelRunner; |
|||
} |
|||
|
|||
public void Dispose() |
|||
{ |
|||
Academy.Instance.TrainerUpdate -= Update; |
|||
} |
|||
|
|||
public void Update() |
|||
{ |
|||
var samples = m_Buffer.Sample(batchSize); |
|||
// states = [s.state for s in samples]
|
|||
// actions = [s.action for s in samples]
|
|||
// q_values = policy_net(states).gather(1, action_batch)
|
|||
|
|||
// next_states = [s.next_state for s in samples]
|
|||
// rewards = [s.reward for s in samples]
|
|||
// next_state_values = target_net(non_final_next_states).max(1)[0]
|
|||
// expected_q_values = (next_state_values * GAMMA) + rewards
|
|||
|
|||
// loss = MSE(q_values, expected_q_values);
|
|||
// m_ModelRunner.model = Barracuda.ModelUpdate(m_ModelRunner.model, loss);
|
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 8dd9e7f1621bd487998fd883b2518733 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
撰写
预览
正在加载...
取消
保存
Reference in new issue