|
|
|
|
|
|
bool m_Verbose = false; |
|
|
|
IReadOnlyList<TensorProxy> m_TrainingInputs; |
|
|
|
IReadOnlyList<TensorProxy> m_InferenceInputs; |
|
|
|
string[] m_TrainingOutputNames; |
|
|
|
string[] m_InferenceOutputNames; |
|
|
|
List<TensorProxy> m_TrainingOutputs; |
|
|
|
Dictionary<string, Tensor> m_InputsByName; |
|
|
|
Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>(); |
|
|
|
|
|
|
actionSpec, seed, m_TensorAllocator, barracudaModel); |
|
|
|
m_InputsByName = new Dictionary<string, Tensor>(); |
|
|
|
m_TrainingOutputs = new List<TensorProxy>(); |
|
|
|
m_TrainingOutputNames = new string[] {TensorNames.TrainingStateOut, TensorNames.OuputLoss}; |
|
|
|
m_InferenceOutputNames = new string[] {TensorNames.TrainingOutput}; |
|
|
|
m_Buffer = buffer; |
|
|
|
InitializeTrainingState(); |
|
|
|
} |
|
|
|
|
|
|
{ |
|
|
|
name = TensorNames.InitialTrainingState, |
|
|
|
valueType = TensorProxy.TensorType.FloatingPoint, |
|
|
|
data = initState, |
|
|
|
data = initState.DeepCopy(), |
|
|
|
shape = initState.shape.ToArray().Select(i => (long)i).ToArray() |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
// Execute the Model
|
|
|
|
m_Engine.Execute(m_InputsByName); |
|
|
|
|
|
|
|
FetchBarracudaOutputs(new string[] { TensorNames.TrainingOutput }); |
|
|
|
FetchBarracudaOutputs(m_InferenceOutputNames); |
|
|
|
|
|
|
|
// Update the outputs
|
|
|
|
m_TensorApplier.ApplyTensors(m_TrainingOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived); |
|
|
|
|
|
|
m_OrderedAgentsRequestingDecisions.Clear(); |
|
|
|
} |
|
|
|
|
|
|
|
public void UpdateModel(List<Transition> transitions) |
|
|
|
public float UpdateModel(List<Transition> transitions) |
|
|
|
return; |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions, m_TrainingState, true); |
|
|
|
|
|
|
m_Engine.Execute(m_InputsByName); |
|
|
|
|
|
|
|
// Update the model
|
|
|
|
FetchBarracudaOutputs(new string[] { TensorNames.TrainingStateOut }); |
|
|
|
FetchBarracudaOutputs(m_TrainingOutputNames); |
|
|
|
TensorUtils.CopyTensor(m_TrainingOutputs[0], m_TrainingState); |
|
|
|
|
|
|
|
// UnityEngine.Debug.Log(m_TrainingState.data[0]);
|
|
|
|
|
|
|
// }
|
|
|
|
// throw new System.Exception("STOP");
|
|
|
|
return m_TrainingOutputs[1].data[0]; |
|
|
|
} |
|
|
|
|
|
|
|
public ActionBuffers GetAction(int agentId) |
|
|
|