|
|
|
|
|
|
public int bufferSize = 1024; |
|
|
|
public int batchSize = 64; |
|
|
|
public float gamma = 0.99f; |
|
|
|
public int updateTargetFreq = 200; |
|
|
|
} |
|
|
|
|
|
|
|
internal class Trainer: IDisposable |
|
|
|
|
|
|
TrainingModelRunner m_TargetModelRunner; |
|
|
|
string m_behaviorName; |
|
|
|
TrainerConfig m_Config; |
|
|
|
int m_TrainingStep; |
|
|
|
|
|
|
m_behaviorName = behaviorName; |
|
|
|
m_Buffer = new ReplayBuffer(m_Config.bufferSize); |
|
|
|
m_ModelRunner = new TrainingModelRunner(actionSpec, model, seed); |
|
|
|
m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, seed); |
|
|
|
// copy weights from model to target model
|
|
|
|
// m_TargetModelRunner.model.weights = m_ModelRunner.model.weights
|
|
|
|
Academy.Instance.TrainerUpdate += Update; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// loss = MSE(q_values, expected_q_values);
|
|
|
|
// m_ModelRunner.model = Barracuda.ModelUpdate(m_ModelRunner.model, loss);
|
|
|
|
|
|
|
|
|
|
|
|
// Update target network
|
|
|
|
if (m_TrainingStep % m_Config.updateTargetFreq == 0) |
|
|
|
{ |
|
|
|
// copy weights
|
|
|
|
} |
|
|
|
|
|
|
|
m_TrainingStep += 1; |
|
|
|
} |