浏览代码

Move seed randomization to learn.py (#1071)

* Move seed randomization to learn.py

* Remove print statement
/develop-generalizationTraining-TrainerController
GitHub 6 年前
当前提交
9538d699
共有 2 个文件被更改,包括 9 次插入5 次删除
  1. 12
      python/learn.py
  2. 2
      python/unitytrainers/trainer_controller.py

12
python/learn.py


import os
import multiprocessing
import numpy as np
from unitytrainers.trainer_controller import TrainerController
from unitytrainers.exception import TrainerError

TRAINER_CONFIG_PATH = os.path.abspath(os.path.join(base_path, "trainer_config.yaml"))
def run_training(sub_id):
def run_training(sub_id, use_seed):
load_model, train_model, worker_id + sub_id, keep_checkpoints, lesson, seed,
load_model, train_model, worker_id + sub_id, keep_checkpoints, lesson, use_seed,
docker_target_name, TRAINER_CONFIG_PATH, no_graphics)
tc.start_learning()

jobs = []
for i in range(num_runs):
p = multiprocessing.Process(target=run_training, args=(i,))
if seed == -1:
use_seed = np.random.randint(0, 9999)
else:
use_seed = seed
p = multiprocessing.Process(target=run_training, args=(i, use_seed))
jobs.append(p)
p.start()

2
python/unitytrainers/trainer_controller.py


self.worker_id = worker_id
self.keep_checkpoints = keep_checkpoints
self.trainers = {}
if seed == -1:
seed = np.random.randint(0, 999999)
self.seed = seed
np.random.seed(self.seed)
tf.set_random_seed(self.seed)

正在加载...
取消
保存