浏览代码

fix Unity-Technologies/ml-agents#1041 (#1102)

/develop-generalizationTraining-TrainerController
Arthur Juliani 6 年前
当前提交
567ad3f0
共有 1 个文件被更改,包括 43 次插入38 次删除
  1. 81
      python/learn.py

81
python/learn.py


from unitytrainers.trainer_controller import TrainerController
from unitytrainers.exception import TrainerError
def run_training(sub_id, use_seed, options):
# Docker Parameters
if options['--docker-target-name'] == 'Empty':
docker_target_name = ''
else:
docker_target_name = options['--docker-target-name']
# General parameters
run_id = options['--run-id']
num_runs = int(options['--num-runs'])
seed = int(options['--seed'])
load_model = options['--load']
train_model = options['--train']
save_freq = int(options['--save-freq'])
env_path = options['<env>']
keep_checkpoints = int(options['--keep-checkpoints'])
worker_id = int(options['--worker-id'])
curriculum_file = str(options['--curriculum'])
if curriculum_file == "None":
curriculum_file = None
lesson = int(options['--lesson'])
fast_simulation = not bool(options['--slow'])
no_graphics = options['--no-graphics']
# Constants
# Assumption that this yaml is present in same dir as this file
base_path = os.path.dirname(__file__)
TRAINER_CONFIG_PATH = os.path.abspath(os.path.join(base_path, "trainer_config.yaml"))
if env_path is None and num_runs > 1:
raise TrainerError("It is not possible to launch more than one concurrent training session "
"when training from the editor")
tc = TrainerController(env_path, run_id + "-" + str(sub_id), save_freq, curriculum_file, fast_simulation,
load_model, train_model, worker_id + sub_id, keep_checkpoints, lesson, use_seed,
docker_target_name, TRAINER_CONFIG_PATH, no_graphics)
tc.start_learning()
if __name__ == '__main__':
print('''

options = docopt(_USAGE)
logger.info(options)
# Docker Parameters
if options['--docker-target-name'] == 'Empty':
docker_target_name = ''
else:
docker_target_name = options['--docker-target-name']
# General parameters
run_id = options['--run-id']
load_model = options['--load']
train_model = options['--train']
save_freq = int(options['--save-freq'])
env_path = options['<env>']
keep_checkpoints = int(options['--keep-checkpoints'])
worker_id = int(options['--worker-id'])
curriculum_file = str(options['--curriculum'])
if curriculum_file == "None":
curriculum_file = None
lesson = int(options['--lesson'])
fast_simulation = not bool(options['--slow'])
no_graphics = options['--no-graphics']
# Constants
# Assumption that this yaml is present in same dir as this file
base_path = os.path.dirname(__file__)
TRAINER_CONFIG_PATH = os.path.abspath(os.path.join(base_path, "trainer_config.yaml"))
def run_training(sub_id, use_seed):
tc = TrainerController(env_path, run_id + "-" + str(sub_id), save_freq, curriculum_file, fast_simulation,
load_model, train_model, worker_id + sub_id, keep_checkpoints, lesson, use_seed,
docker_target_name, TRAINER_CONFIG_PATH, no_graphics)
tc.start_learning()
if env_path is None and num_runs > 1:
raise TrainerError("It is not possible to launch more than one concurrent training session "
"when training from the editor")
jobs = []
for i in range(num_runs):

use_seed = seed
p = multiprocessing.Process(target=run_training, args=(i, use_seed))
p = multiprocessing.Process(target=run_training, args=(i, use_seed, options))
p.start()
p.start()
正在加载...
取消
保存