Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

79 行
2.2 KiB

// Trainer for C# training. One trainer per behavior.
using System;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;
namespace Unity.MLAgents
{
/// <summary>
/// Hyperparameters for a <c>Trainer</c>. Every value has a default, so a
/// freshly constructed instance is usable as-is; <c>Trainer</c> creates one
/// itself when no config is supplied.
/// </summary>
internal class TrainerConfig
{
    // --- Replay / sampling ---

    /// <summary>Value passed to the <c>ReplayBuffer</c> constructor by <c>Trainer</c>.</summary>
    public int bufferSize = 100;

    /// <summary>
    /// Number of samples requested from the buffer per update. <c>Trainer.Update</c>
    /// also waits until the buffer holds at least twice this many entries.
    /// </summary>
    public int batchSize = 4;

    // --- Optimisation ---
    // Neither value below is read by Trainer directly; both reach
    // TrainingModelRunner through the config object.

    /// <summary>Presumably the optimizer step size; confirm in <c>TrainingModelRunner</c>.</summary>
    public float learningRate = 0.0005f;

    /// <summary>Presumably the reward discount factor; confirm in <c>TrainingModelRunner</c>.</summary>
    public float gamma = 0.99f;

    // --- Target network ---

    /// <summary>
    /// Interval, in training steps, at which <c>Trainer.Update</c> enters its
    /// target-network synchronisation branch.
    /// </summary>
    public int updateTargetFreq = 200;
}
/// <summary>
/// Drives training for one behavior: owns the replay buffer and the model
/// runners, and performs one training step each time the Academy raises
/// <c>TrainerUpdate</c>. Dispose the trainer to detach it from that event.
/// </summary>
internal class Trainer : IDisposable
{
    readonly ReplayBuffer m_Buffer;
    readonly TrainingModelRunner m_ModelRunner;

    // Built with the same arguments as m_ModelRunner. Nothing in this class
    // reads it yet because the weight-copy step is still a TODO.
    readonly TrainingModelRunner m_TargetModelRunner;

    readonly string m_BehaviorName;
    readonly TrainerConfig m_Config;

    // Number of completed calls to UpdateModel.
    int m_TrainingStep;

    /// <summary>
    /// Creates the trainer and subscribes it to <c>Academy.Instance.TrainerUpdate</c>.
    /// </summary>
    /// <param name="behaviorName">Name of the behavior this trainer is responsible for.</param>
    /// <param name="actionSpec">Action specification forwarded to both model runners.</param>
    /// <param name="model">Model forwarded to both model runners.</param>
    /// <param name="seed">Seed forwarded to both model runners.</param>
    /// <param name="config">Hyperparameters; a default <c>TrainerConfig</c> is used when null.</param>
    public Trainer(string behaviorName, ActionSpec actionSpec, NNModel model, int seed = 0, TrainerConfig config = null)
    {
        m_Config = config ?? new TrainerConfig();
        m_BehaviorName = behaviorName;
        m_Buffer = new ReplayBuffer(m_Config.bufferSize);
        m_ModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, m_Config, seed);
        m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, m_Config, seed);

        // TODO: copy the weights of m_ModelRunner into m_TargetModelRunner once
        // TrainingModelRunner exposes a way to do so.

        Academy.Instance.TrainerUpdate += Update;
    }

    /// <summary>Name of the behavior this trainer was created for.</summary>
    public string BehaviorName => m_BehaviorName;

    /// <summary>Replay buffer that <see cref="Update"/> samples from.</summary>
    public ReplayBuffer Buffer => m_Buffer;

    /// <summary>The model runner that is trained by <see cref="Update"/>.</summary>
    public TrainingModelRunner TrainerModelRunner => m_ModelRunner;

    /// <summary>
    /// Detaches this trainer from the Academy's <c>TrainerUpdate</c> event.
    /// Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        // Per the ML-Agents Academy API, Instance creates the singleton on first
        // access. Checking IsInitialized keeps a late Dispose (e.g. during
        // shutdown) from spinning up a new Academy just to unsubscribe from it.
        if (Academy.IsInitialized)
        {
            Academy.Instance.TrainerUpdate -= Update;
        }

        // NOTE(review): the model runners are not disposed here; confirm whether
        // TrainingModelRunner holds resources that need explicit release.
    }

    /// <summary>
    /// Runs one training step: samples a batch from the replay buffer and hands
    /// it to the model runner. Does nothing until the buffer holds at least
    /// twice the configured batch size.
    /// </summary>
    public void Update()
    {
        if (m_Buffer.Count < m_Config.batchSize * 2)
        {
            return;
        }

        var samples = m_Buffer.SampleBatch(m_Config.batchSize);
        m_ModelRunner.UpdateModel(samples);

        // The config fields are public and mutable, so the frequency can be 0 at
        // any time; an integer modulo by zero throws DivideByZeroException.
        // Treat 0 as "never synchronise the target network".
        var targetFreq = m_Config.updateTargetFreq;
        if (targetFreq != 0 && m_TrainingStep % targetFreq == 0)
        {
            // TODO: copy weights from m_ModelRunner to m_TargetModelRunner.
            // Note this branch is also taken on the very first step (step 0).
        }

        m_TrainingStep += 1;
    }
}
}