浏览代码

add target network prototype

/ai-hw-2021
Ruo-Ping Dong 3 年前
当前提交
ae20c760
共有 1 个文件被更改，包括 12 次插入和 0 次删除
  1. 12
      com.unity.ml-agents/Runtime/Trainer.cs

12
com.unity.ml-agents/Runtime/Trainer.cs


public int bufferSize = 1024;
public int batchSize = 64;
public float gamma = 0.99f;
public int updateTargetFreq = 200;
}
internal class Trainer: IDisposable

TrainingModelRunner m_TargetModelRunner;
string m_behaviorName;
TrainerConfig m_Config;
int m_TrainingStep;

m_behaviorName = behaviorName;
m_Buffer = new ReplayBuffer(m_Config.bufferSize);
m_ModelRunner = new TrainingModelRunner(actionSpec, model, seed);
m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, seed);
// copy weights from model to target model
// m_TargetModelRunner.model.weights = m_ModelRunner.model.weights
Academy.Instance.TrainerUpdate += Update;
}

// loss = MSE(q_values, expected_q_values);
// m_ModelRunner.model = Barracuda.ModelUpdate(m_ModelRunner.model, loss);
// Update target network
if (m_TrainingStep % m_Config.updateTargetFreq == 0)
{
// copy weights
}
m_TrainingStep += 1;
}
正在加载...
取消
保存