|
|
|
|
|
|
bool m_Verbose = false; |
|
|
|
string[] m_TrainingOutputNames; |
|
|
|
IReadOnlyList<TensorProxy> m_TrainingInputs; |
|
|
|
IReadOnlyList<TensorProxy> m_InferenceInputs; |
|
|
|
List<TensorProxy> m_TrainingOutputs; |
|
|
|
Dictionary<string, Tensor> m_InputsByName; |
|
|
|
Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>(); |
|
|
|
|
|
|
m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose); |
|
|
|
|
|
|
|
m_TrainingInputs = barracudaModel.GetTrainingInputTensors(); |
|
|
|
List<TensorProxy> infTensors = new List<TensorProxy>(); |
|
|
|
foreach(var tensor in m_TrainingInputs) |
|
|
|
{ |
|
|
|
if (tensor.name == TensorNames.Observations || tensor.name == TensorNames.BatchSizePlaceholder) |
|
|
|
{ |
|
|
|
infTensors.Add(tensor); |
|
|
|
} |
|
|
|
} |
|
|
|
m_InferenceInputs = (IReadOnlyList<TensorProxy>) infTensors; |
|
|
|
m_TrainingOutputNames = barracudaModel.GetTrainingOutputNames(); |
|
|
|
m_TensorGenerator = new TensorGenerator( |
|
|
|
seed, m_TensorAllocator, m_Memories, barracudaModel); |
|
|
|
|
|
|
m_InputsByName = new Dictionary<string, Tensor>(); |
|
|
|
m_TrainingOutputs = new List<TensorProxy>(); |
|
|
|
m_Buffer = buffer; |
|
|
|
InitializeTrainingState(); |
|
|
|
} |
|
|
|
|
|
|
|
void InitializeTrainingState() |
|
|
|
|
|
|
var inp = infInputs[i]; |
|
|
|
m_InputsByName[inp.name] = inp.data; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public ITensorAllocator Allocator |
|
|
|
{ |
|
|
|
get => m_TensorAllocator; |
|
|
|
} |
|
|
|
|
|
|
|
public void Dispose() |
|
|
|
|
|
|
|
|
|
|
public IReadOnlyList<TensorProxy> GetInputTensors() |
|
|
|
{ |
|
|
|
return m_Model.GetInputTensors(); |
|
|
|
return m_Model.GetTrainingObservationInputTensors(); |
|
|
|
} |
|
|
|
|
|
|
|
public void DecideBatch() |
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
// Prepare the input tensors to be feed into the engine
|
|
|
|
m_TensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Infos); |
|
|
|
m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Infos); |
|
|
|
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Buffer.SampleDummyBatch(currentBatchSize), m_TrainingState); |
|
|
|
|
|
|
|
PrepareBarracudaInputs(m_TrainingInputs); |
|
|
|