|
|
|
|
|
|
import logging |
|
|
|
|
|
|
|
import os |
|
|
|
import ray |
|
|
|
import multiprocessing |
|
|
|
from docopt import docopt |
|
|
|
|
|
|
|
from unitytrainers.trainer_controller import TrainerController |
|
|
|
|
|
|
base_path = os.path.dirname(__file__) |
|
|
|
TRAINER_CONFIG_PATH = os.path.abspath(os.path.join(base_path, "trainer_config.yaml")) |
|
|
|
|
|
|
|
@ray.remote |
|
|
|
def run_training(sub_id): |
|
|
|
tc = TrainerController(env_path, run_id+"-"+str(sub_id), save_freq, curriculum_file, fast_simulation, |
|
|
|
load_model, train_model, worker_id+sub_id, keep_checkpoints, lesson, seed, |
|
|
|
|
|
|
ray.init() |
|
|
|
ray.get([run_training.remote(i) for i in range(num_runs)]) |
|
|
|
jobs = [] |
|
|
|
for i in range(num_runs): |
|
|
|
p = multiprocessing.Process(target=run_training, args=(i,)) |
|
|
|
jobs.append(p) |
|
|
|
p.start() |