Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题。主题必须以中文、字母或数字开头，可以包含连字符（-），并且长度不得超过35个字符。
 
 
 
 
 

68 行
1.9 KiB

// Trainer for C# training. One trainer per behavior.
using System;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
namespace Unity.MLAgents
{
    /// <summary>
    /// In-engine (C#) trainer for a single behavior. It owns a <see cref="ReplayBuffer"/> and a
    /// <see cref="TrainingModelRunner"/>, and subscribes <see cref="Update"/> to
    /// <c>Academy.TrainerUpdate</c> when constructed.
    /// </summary>
    internal class Trainer : IDisposable
    {
        // Number of batches that must be buffered before any sampling happens.
        const int k_WarmupBatches = 2;

        readonly ReplayBuffer m_Buffer;
        readonly TrainingModelRunner m_ModelRunner;
        readonly string m_BehaviorName;
        readonly int m_BatchSize;

        /// <summary>
        /// Creates a trainer for <paramref name="behaviorName"/> and subscribes it to
        /// <c>Academy.TrainerUpdate</c>.
        /// </summary>
        /// <param name="behaviorName">Name of the behavior this trainer is responsible for.</param>
        /// <param name="actionSpec">Action specification forwarded to the <see cref="TrainingModelRunner"/>.</param>
        /// <param name="seed">Seed forwarded to the <see cref="TrainingModelRunner"/>.</param>
        /// <param name="bufferSize">Value passed to the <see cref="ReplayBuffer"/> constructor. Must be positive.</param>
        /// <param name="batchSize">Number of experiences sampled per <see cref="Update"/>. Must be positive.</param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="bufferSize"/> or <paramref name="batchSize"/> is zero or negative.
        /// </exception>
        public Trainer(string behaviorName, ActionSpec actionSpec, int seed = 0, int bufferSize = 1024, int batchSize = 64)
        {
            if (bufferSize <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(bufferSize), bufferSize, "Buffer size must be positive.");
            }
            if (batchSize <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be positive.");
            }
            // NOTE(review): if ReplayBuffer caps its Count at bufferSize, then a bufferSize below
            // k_WarmupBatches * batchSize means Update never samples — confirm against ReplayBuffer.

            m_BehaviorName = behaviorName;
            m_BatchSize = batchSize;
            m_Buffer = new ReplayBuffer(bufferSize);
            m_ModelRunner = new TrainingModelRunner(actionSpec, seed);

            // Subscribe only after every field has been initialised.
            Academy.Instance.TrainerUpdate += Update;
        }

        /// <summary>Name of the behavior this trainer was created for.</summary>
        public string BehaviorName => m_BehaviorName;

        /// <summary>The replay buffer owned by this trainer.</summary>
        public ReplayBuffer Buffer => m_Buffer;

        /// <summary>The model runner owned by this trainer.</summary>
        public TrainingModelRunner TrainerModelRunner => m_ModelRunner;

        /// <summary>
        /// Unsubscribes this trainer from <c>Academy.TrainerUpdate</c>.
        /// </summary>
        public void Dispose()
        {
            // Academy.Instance initialises an Academy on first access. If the Academy has already
            // been torn down, touching Instance here would create a brand-new one just to unsubscribe.
            if (Academy.IsInitialized)
            {
                Academy.Instance.TrainerUpdate -= Update;
            }
            // NOTE(review): m_ModelRunner is not released here; if TrainingModelRunner holds
            // disposable resources (e.g. a worker), dispose it as well — confirm its type.
        }

        /// <summary>
        /// Training callback invoked through <c>Academy.TrainerUpdate</c>. It returns immediately
        /// until the buffer holds at least <c>k_WarmupBatches * batchSize</c> experiences. After
        /// that, it samples one batch. The learning step itself is not implemented yet.
        /// </summary>
        public void Update()
        {
            // Compare in long so that k_WarmupBatches * m_BatchSize cannot overflow int.
            if (m_Buffer.Count < k_WarmupBatches * (long)m_BatchSize)
            {
                return;
            }

            var samples = m_Buffer.SampleBatch(m_BatchSize);

            // TODO: DQN-style update, not implemented yet. Intended outline (pseudo-code):
            //   states      = [s.state for s in samples]
            //   actions     = [s.action for s in samples]
            //   rewards     = [s.reward for s in samples]
            //   next_states = [s.next_state for s in samples]
            //   q_values          = policy_net(states).gather(1, actions)
            //   next_state_values = target_net(next_states).max(1)[0]  (0 for terminal next states)
            //   expected_q_values = rewards + gamma * next_state_values
            //   loss              = MSE(q_values, expected_q_values)
            //   then apply a gradient step to m_ModelRunner's model using loss.
            // A discount factor gamma must be supplied (e.g. as a constructor parameter) before this
            // is implemented; a default-initialised float would be 0 and reduce the target to rewards.
            // NOTE(review): Unity.Barracuda is an inference library — confirm where the gradient
            // step is meant to come from.
        }
    }
}