浏览代码

add config class

/ai-hw-2021
Ruo-Ping Dong 4 年前
当前提交
0ec29858
共有 2 个文件被更改，包括 19 次插入和 10 次删除
  1. 2
      com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs
  2. 27
      com.unity.ml-agents/Runtime/Trainer.cs

2
com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs


// barracudaModel = ModelLoader.Load(new NNModel());
barracudaModel = ModelLoader.Load(model);
m_Model = barracudaModel;
WorkerFactory.Type executionDevice = WorkerFactory.Type.CSharpBurst;
WorkerFactory.Type executionDevice = WorkerFactory.Type.CSharp;
m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose);
m_InferenceInputs = barracudaModel.GetInputTensors();

27
com.unity.ml-agents/Runtime/Trainer.cs


using System;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;
/// <summary>
/// Groups the tunable settings read by <c>Trainer</c>. The <c>Trainer</c> constructor
/// accepts an optional instance and substitutes <c>new TrainerConfig()</c> when it is
/// given <c>null</c>, so the initializers below act as the defaults.
/// </summary>
internal class TrainerConfig
{
/// <summary>
/// Value handed to the <c>ReplayBuffer</c> constructor when the trainer is created.
/// NOTE(review): presumably the buffer's maximum capacity; confirm against ReplayBuffer.
/// </summary>
public int bufferSize = 1024;
/// <summary>
/// Number of samples requested from the replay buffer via <c>SampleBatch</c> on each
/// <c>Update</c>. <c>Update</c> also compares the buffer's count against twice this value
/// before sampling.
/// </summary>
public int batchSize = 64;
/// <summary>
/// NOTE(review): not read by any code visible here; presumably the reward discount
/// factor used when computing expected Q-values. TODO confirm once the update step is
/// implemented.
/// </summary>
public float gamma = 0.99f;
}
int m_BufferSize = 1024;
int batchSize = 64;
float GAMMA;
TrainerConfig m_Config;
int m_TrainingStep;
public Trainer(string behaviorName, ActionSpec actionSpec, NNModel model, int seed=0)
public Trainer(string behaviorName, ActionSpec actionSpec, NNModel model, int seed=0, TrainerConfig config=null)
m_Config = config ?? new TrainerConfig();
m_Buffer = new ReplayBuffer(m_BufferSize);
m_Buffer = new ReplayBuffer(m_Config.bufferSize);
m_ModelRunner = new TrainingModelRunner(actionSpec, model, seed);
Academy.Instance.TrainerUpdate += Update;
}

public void Update()
{
if (m_Buffer.Count < batchSize * 2)
if (m_Buffer.Count < m_Config.batchSize * 2)
var samples = m_Buffer.SampleBatch(batchSize);
var samples = m_Buffer.SampleBatch(m_Config.batchSize);
// q_values = policy_net(states).gather(1, action_batch)
// q_values = policy_net(states).gather(1, actions)
// next_states = [s.next_state for s in samples]
// rewards = [s.reward for s in samples]

// loss = MSE(q_values, expected_q_values);
// m_ModelRunner.model = Barracuda.ModelUpdate(m_ModelRunner.model, loss);
m_TrainingStep += 1;
}
}
}
正在加载...
取消
保存