浏览代码

init training state

/ai-hw-2021
Ruo-Ping Dong 3 年前
当前提交
d2781be0
共有 1 个文件被更改,包括 9 次插入4 次删除
  1. 13
      com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs

13
com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs


using Unity.MLAgents.Sensors;
using UnityEngine;
using Unity.MLAgents.Inference.Utils;
using System.Linq;
namespace Unity.MLAgents
{

Model m_Model;
IWorker m_Engine;
bool m_Verbose = false;
string[] m_OutputNames;
string[] m_TrainingOutputNames;
IReadOnlyList<TensorProxy> m_TrainingInputs;
List<TensorProxy> m_TrainingOutputs;

m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose);
m_TrainingInputs = barracudaModel.GetTrainingInputTensors();
m_OutputNames = barracudaModel.GetOutputNames();
m_TrainingOutputNames = barracudaModel.GetTrainingOutputNames();
m_TensorGenerator = new TensorGenerator(
seed, m_TensorAllocator, m_Memories, barracudaModel);

void InitializeTrainingState()
{
// TODO: initialize m_TrainingState
var initState = m_Model.GetTensorByName(TensorNames.InitialTrainingState);
m_TrainingState = new TensorProxy{
name = TensorNames.InitialTrainingState,
valueType = TensorProxy.TensorType.FloatingPoint,
data = initState,
shape = initState.shape.ToArray().Select(i => (long)i).ToArray()
};
}
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)

// Execute the Model
m_Engine.Execute(m_InputsByName);
FetchBarracudaOutputs(m_OutputNames);
FetchBarracudaOutputs(m_TrainingOutputNames);
// Update the outputs
m_TensorApplier.ApplyTensors(m_TrainingOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);

正在加载...
取消
保存