|
|
|
|
|
|
using Unity.MLAgents.Sensors; |
|
|
|
using UnityEngine; |
|
|
|
using Unity.MLAgents.Inference.Utils; |
|
|
|
using System.Linq; |
|
|
|
|
|
|
|
namespace Unity.MLAgents |
|
|
|
{ |
|
|
|
|
|
|
Model m_Model; |
|
|
|
IWorker m_Engine; |
|
|
|
bool m_Verbose = false; |
|
|
|
string[] m_OutputNames; |
|
|
|
string[] m_TrainingOutputNames; |
|
|
|
IReadOnlyList<TensorProxy> m_TrainingInputs; |
|
|
|
List<TensorProxy> m_TrainingOutputs; |
|
|
|
|
|
|
m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose); |
|
|
|
|
|
|
|
m_TrainingInputs = barracudaModel.GetTrainingInputTensors(); |
|
|
|
m_OutputNames = barracudaModel.GetOutputNames(); |
|
|
|
m_TrainingOutputNames = barracudaModel.GetTrainingOutputNames(); |
|
|
|
m_TensorGenerator = new TensorGenerator( |
|
|
|
seed, m_TensorAllocator, m_Memories, barracudaModel); |
|
|
|
|
|
|
|
|
|
|
void InitializeTrainingState() |
|
|
|
{ |
|
|
|
// TODO: initialize m_TrainingState
|
|
|
|
var initState = m_Model.GetTensorByName(TensorNames.InitialTrainingState); |
|
|
|
m_TrainingState = new TensorProxy{ |
|
|
|
name = TensorNames.InitialTrainingState, |
|
|
|
valueType = TensorProxy.TensorType.FloatingPoint, |
|
|
|
data = initState, |
|
|
|
shape = initState.shape.ToArray().Select(i => (long)i).ToArray() |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs) |
|
|
|
|
|
|
// Execute the Model
|
|
|
|
m_Engine.Execute(m_InputsByName); |
|
|
|
|
|
|
|
FetchBarracudaOutputs(m_OutputNames); |
|
|
|
FetchBarracudaOutputs(m_TrainingOutputNames); |
|
|
|
|
|
|
|
// Update the outputs
|
|
|
|
m_TensorApplier.ApplyTensors(m_TrainingOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived); |
|
|
|