// Trainer for C# training. One trainer per behavior.

using System;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;

namespace Unity.MLAgents
{
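    internal class TrainerConfig
    {
        // Capacity of the replay buffer (number of stored experiences).
        public int bufferSize = 100;
        // Number of experiences sampled for each model update.
        public int batchSize = 4;
        // Discount factor applied to future rewards.
        public float gamma = 0.99f;
        // Step size used when updating the model.
        public float learningRate = 0.0005f;
        // Number of training steps between target-network updates.
        public int updateTargetFreq = 200;
    }

    internal class Trainer : IDisposable
    {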
        ReplayBuffer m_Buffer;
        // Online model, updated from batches sampled out of the replay buffer.
        TrainingModelRunner m_ModelRunner;
        // Target model, meant to be synced with m_ModelRunner every updateTargetFreq steps.
        TrainingModelRunner m_TargetModelRunner;
        string m_behaviorName;
        TrainerConfig m_Config;
        int m_TrainingStep;

        public Trainer(string behaviorName, ActionSpec actionSpec, NNModel model, int seed = 0, TrainerConfig config = null)
        {
            m_Config = config ?? new TrainerConfig();
            m_behaviorName = behaviorName;
            m_Buffer = new ReplayBuffer(m_Config.bufferSize);
            m_ModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, m_Config, seed);
            m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, m_Config, seed);

            // TODO: copy the weights of m_ModelRunner into m_TargetModelRunner so that
            // both models start out identical (pseudocode):
            // m_TargetModelRunner.model.weights = m_ModelRunner.model.weights

            Academy.Instance.TrainerUpdate += Update;
        }

        public string BehaviorName
        {
            get => m_behaviorName;
        }

        public ReplayBuffer Buffer
        {
            get => m_Buffer;
        }

        public TrainingModelRunner TrainerModelRunner
        {
            get => m_ModelRunner;
        }

        public void Dispose()
        {
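            // Accessing Academy.Instance creates the Academy if it does not exist,
            // so only unsubscribe while it is still initialized (e.g. during shutdown).
            if (Academy.IsInitialized)
            {
                Academy.Instance.TrainerUpdate -= Update;
            }
        }

        public void Update()
        {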
            // Wait until the buffer holds at least two batches of experience before training.
            if (m_Buffer.Count < m_Config.batchSize * 2)
            {
                return;
            }

            var samples = m_Buffer.SampleBatch(m_Config.batchSize);
            m_ModelRunner.UpdateModel(samples);

            // Update the target network every updateTargetFreq training steps.
            if (m_TrainingStep % m_Config.updateTargetFreq == 0)
            {
                // TODO: copy weights from m_ModelRunner to m_TargetModelRunner.
            }

            m_TrainingStep += 1;
        }
    }
}