浏览代码

temporarily initialize model with trained model (to be reverted)

/ai-hw-2021
Ruo-Ping Dong 4 年前
当前提交
d27ae3fc
共有 5 个文件被更改,包括 10 次插入7 次删除
  1. 2
      com.unity.ml-agents/Runtime/Academy.cs
  2. 2
      com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
  3. 6
      com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs
  4. 3
      com.unity.ml-agents/Runtime/Policies/TrainingPolicy.cs
  5. 4
      com.unity.ml-agents/Runtime/Trainer.cs

2
com.unity.ml-agents/Runtime/Academy.cs


var trainer = m_Trainers.Find(x => x.BehaviorName == behaviorName);
if (trainer == null)
{
trainer = new Trainer(behaviorName, actionSpec);
trainer = new Trainer(behaviorName, actionSpec, model);
m_Trainers.Add(trainer);
}
return trainer;

2
com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs


return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName);
}
case BehaviorType.InEditorTraining:
return new TrainingPolicy(actionSpec, m_BehaviorName);
return new TrainingPolicy(actionSpec, m_BehaviorName, m_Model);
case BehaviorType.Default:
if (Academy.Instance.IsCommunicatorOn)
{

6
com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs


/// </exception>
public TrainingModelRunner(
ActionSpec actionSpec,
NNModel model,
int seed = 0)
{
Model barracudaModel;

barracudaModel = ModelLoader.Load(new NNModel());
// barracudaModel = ModelLoader.Load(new NNModel());
barracudaModel = ModelLoader.Load(model);
m_Model = barracudaModel;
m_InferenceInputs = barracudaModel.GetInputTensors();
m_OutputNames = barracudaModel.GetOutputNames();

3
com.unity.ml-agents/Runtime/Policies/TrainingPolicy.cs


/// <inheritdoc />
public TrainingPolicy(
ActionSpec actionSpec,
string behaviorName
string behaviorName,
NNModel model
)
{
var trainer = Academy.Instance.GetOrCreateTrainer(behaviorName, actionSpec, model);

4
com.unity.ml-agents/Runtime/Trainer.cs


int batchSize = 64;
float GAMMA;
public Trainer(string behaviorName, ActionSpec actionSpec, int seed=0)
public Trainer(string behaviorName, ActionSpec actionSpec, NNModel model, int seed=0)
m_ModelRunner = new TrainingModelRunner(actionSpec, seed);
m_ModelRunner = new TrainingModelRunner(actionSpec, model, seed);
Academy.Instance.TrainerUpdate += Update;
}

正在加载...
取消
保存