|
|
|
|
|
|
using Unity.MLAgents.Policies; |
|
|
|
using Unity.MLAgents.Sensors; |
|
|
|
using UnityEngine; |
|
|
|
using Unity.MLAgents.Inference.Utils; |
|
|
|
|
|
|
|
namespace Unity.MLAgents |
|
|
|
{ |
|
|
|
|
|
|
// Names of the model's training outputs; assigned in the constructor from
// barracudaModel.GetTrainingOutputNames().
string[] m_TrainingOutputNames; |
|
|
|
// Input tensors passed to PrepareBarracudaInputs without the training flag.
IReadOnlyList<TensorProxy> m_InferenceInputs; |
|
|
|
// Input tensors filled by m_TrainingTensorGenerator.GenerateTensors and copied into
// m_InputsByName by PrepareBarracudaInputs.
IReadOnlyList<TensorProxy> m_TrainingInputs; |
|
|
|
// Model-parameter tensors obtained from barracudaModel.GetModelParamTensors();
// randomly initialized by InitializeModelParam and always copied into m_InputsByName.
IReadOnlyList<TensorProxy> m_ModelParametersInputs; |
|
|
|
// Created as an empty list in the constructor.
List<TensorProxy> m_InferenceOutputs; |
|
|
|
// Name -> tensor map that PrepareBarracudaInputs clears and refills, and that is
// passed to m_Engine.Execute.
Dictionary<string, Tensor> m_InputsByName; |
|
|
|
// Passed to the TensorGenerator constructor.
// NOTE(review): presumably per-agent memories keyed by agent id -- confirm.
Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>(); |
|
|
|
|
|
|
// Set to true once m_TensorGenerator.InitializeObservations has been called.
bool m_ObservationsInitialized; |
|
|
|
// Set to true once m_TrainingTensorGenerator.InitializeObservations has been called,
// so that the call is made only while this flag is still false.
bool m_TrainingObservationsInitialized; |
|
|
|
|
|
|
|
// Replay buffer supplied to the constructor; SampleDummyBatch is called on it to
// obtain transitions for the training tensor generator.
ReplayBuffer m_Buffer; |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Initializes the Brain with the Model that it will use when selecting actions for
/// the agents.
/// </summary>
|
|
|
|
|
|
|
public TrainingModelRunner( |
|
|
|
ActionSpec actionSpec, |
|
|
|
NNModel model, |
|
|
|
ReplayBuffer buffer, |
|
|
|
int seed = 0) |
|
|
|
{ |
|
|
|
Model barracudaModel; |
|
|
|
|
|
|
// barracudaModel = ModelLoader.Load(new NNModel());
|
|
|
|
barracudaModel = ModelLoader.Load(model); |
|
|
|
m_Model = barracudaModel; |
|
|
|
WorkerFactory.Type executionDevice = WorkerFactory.Type.CSharp; |
|
|
|
WorkerFactory.Type executionDevice = WorkerFactory.Type.CSharpBurst; |
|
|
|
m_ModelParametersInputs = barracudaModel.GetModelParamTensors(); |
|
|
|
InitializeModelParam(); |
|
|
|
m_OutputNames = barracudaModel.GetOutputNames(); |
|
|
|
m_TrainingOutputNames = barracudaModel.GetTrainingOutputNames(); |
|
|
|
m_TensorGenerator = new TensorGenerator( |
|
|
|
|
|
|
actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel); |
|
|
|
m_InputsByName = new Dictionary<string, Tensor>(); |
|
|
|
m_InferenceOutputs = new List<TensorProxy>(); |
|
|
|
m_Buffer = buffer; |
|
|
|
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs) |
|
|
|
/// <summary>
/// Randomly initializes every tensor in <c>m_ModelParametersInputs</c> by calling
/// <c>TensorUtils.RandomInitialize</c> on it with one shared <c>RandomNormal</c> instance.
/// </summary>
/// <param name="seed">
/// Value passed to the <c>RandomNormal</c> constructor. Defaults to 10, the value that
/// was previously hard-coded, so existing parameterless calls behave exactly as before.
/// </param>
void InitializeModelParam(int seed = 10)
{
    // The same generator instance is handed to every tensor in turn.
    var randomNormal = new RandomNormal(seed);
    foreach (var tensor in m_ModelParametersInputs)
    {
        TensorUtils.RandomInitialize(tensor, randomNormal, m_TensorAllocator);
    }
}
|
|
|
|
|
|
|
/// <summary>
/// Rebuilds <c>m_InputsByName</c> (the name-to-tensor map passed to the engine) from,
/// in order: <paramref name="infInputs"/>, <c>m_TrainingInputs</c> and
/// <c>m_ModelParametersInputs</c>.
/// </summary>
/// <param name="infInputs">Tensors registered first, keyed by their name.</param>
/// <param name="training">
/// When <c>false</c> (the default), a training input whose name is already registered is
/// skipped, so the entry coming from <paramref name="infInputs"/> is kept. When
/// <c>true</c>, training inputs overwrite any entry with the same name.
/// </param>
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs, bool training = false)
{
    m_InputsByName.Clear();

    // Register the caller-supplied inputs first.
    for (var i = 0; i < infInputs.Count; i++)
    {
        var inp = infInputs[i];
        m_InputsByName[inp.name] = inp.data;
    }

    // Add the training inputs; outside of training they never replace an existing entry.
    for (var i = 0; i < m_TrainingInputs.Count; i++)
    {
        var inp = m_TrainingInputs[i];
        if (!training && m_InputsByName.ContainsKey(inp.name))
        {
            continue;
        }

        m_InputsByName[inp.name] = inp.data;
    }

    // Model parameters are written last and unconditionally, so they win on a name clash.
    for (var i = 0; i < m_ModelParametersInputs.Count; i++)
    {
        var inp = m_ModelParametersInputs[i];
        m_InputsByName[inp.name] = inp.data;
    }
}
|
|
|
|
|
|
|
public void Dispose() |
|
|
|
|
|
|
m_TensorGenerator.InitializeObservations(firstInfo.sensors, m_TensorAllocator); |
|
|
|
m_ObservationsInitialized = true; |
|
|
|
} |
|
|
|
if (!m_TrainingObservationsInitialized) |
|
|
|
{ |
|
|
|
// Just grab the first agent in the collection (any will suffice, really).
|
|
|
|
// We check for an empty Collection above, so this will always return successfully.
|
|
|
|
m_TrainingTensorGenerator.InitializeObservations(m_Buffer.SampleDummyBatch(1)[0], m_TensorAllocator); |
|
|
|
m_TrainingObservationsInitialized = true; |
|
|
|
} |
|
|
|
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Buffer.SampleDummyBatch(currentBatchSize)); |
|
|
|
|
|
|
|
PrepareBarracudaInputs(m_InferenceInputs); |
|
|
|
|
|
|
|
|
|
|
m_TrainingTensorGenerator.InitializeObservations(transitions[0], m_TensorAllocator); |
|
|
|
m_TrainingObservationsInitialized = true; |
|
|
|
} |
|
|
|
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions); |
|
|
|
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions, true); |
|
|
|
PrepareBarracudaInputs(m_TrainingInputs); |
|
|
|
PrepareBarracudaInputs(m_TrainingInputs, true); |
|
|
|
|
|
|
|
// Execute the Model
|
|
|
|
m_Engine.Execute(m_InputsByName); |
|
|
|
|
|
|
// Update the model
|
|
|
|
// m_Model.weights = m_InferenceOutputs.weights
|
|
|
|
// CopyWeights(w_0, nw_0)
|
|
|
|
} |
|
|
|
|
|
|
|
public ActionBuffers GetAction(int agentId) |
|
|
|