using System;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;

internal class TrainerConfig
{
    public int bufferSize = 1024;
    public int batchSize = 64;
    public float gamma = 0.99f;
}

public class Trainer
{
    TrainerConfig m_Config;
    ReplayBuffer m_Buffer;
    TrainingModelRunner m_ModelRunner;
    int m_TrainingStep;

    public Trainer(string behaviorName, ActionSpec actionSpec, NNModel model, int seed = 0, TrainerConfig config = null)
    {
        m_Config = config ?? new TrainerConfig();
        m_Buffer = new ReplayBuffer(m_Config.bufferSize);
        m_ModelRunner = new TrainingModelRunner(actionSpec, model, seed);
        Academy.Instance.TrainerUpdate += Update;
    }

    public void Update()
    {
        // Skip training until the replay buffer holds enough transitions to sample from.
        if (m_Buffer.Count < m_Config.batchSize * 2)
        {
            return;
        }

        var samples = m_Buffer.SampleBatch(m_Config.batchSize);

        // DQN-style update, sketched in pseudocode:
        // q_values = policy_net(states).gather(1, actions)
        // next_states = [s.next_state for s in samples]
        // rewards = [s.reward for s in samples]
        // expected_q_values = rewards + gamma * max(policy_net(next_states))
        // loss = MSE(q_values, expected_q_values);
        // m_ModelRunner.model = Barracuda.ModelUpdate(m_ModelRunner.model, loss);

        m_TrainingStep += 1;
    }
}
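
// Example wiring (a minimal sketch, not part of the trainer above): a MonoBehaviour
// that constructs a Trainer with a custom TrainerConfig. The component name, the
// "MyBehavior" string, the serialized NNModel field, and the discrete action spec are
// illustrative assumptions; only Trainer, TrainerConfig, ActionSpec, and NNModel come
// from the code above.
public class TrainerBootstrapExample : MonoBehaviour
{
    // Assign a Barracuda model asset in the Inspector.
    public NNModel TrainingModel;

    void Start()
    {
        // Override the defaults declared in TrainerConfig.
        var config = new TrainerConfig
        {
            bufferSize = 4096,
            batchSize = 128,
            gamma = 0.95f,
        };

        // A single discrete action branch with 4 choices, as an example action space.
        var actionSpec = ActionSpec.MakeDiscrete(4);

        var trainer = new Trainer("MyBehavior", actionSpec, TrainingModel, seed: 0, config: config);
    }
}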