浏览代码

Clean up learn.py (#1106)

/develop-generalizationTraining-TrainerController
GitHub 7 年前
当前提交
2edaf342
共有 1 个文件被更改,包括 38 次插入38 次删除
  1. 76
      python/learn.py

76
python/learn.py


# # Unity ML-Agents Toolkit
# ## ML-Agent Learning
import logging

from docopt import docopt
def run_training(sub_id, use_seed, options):
def run_training(sub_id, run_seed, run_options):
"""
Launches training session.
:param sub_id: Unique id for training session.
:param run_seed: Random seed used for training.
:param run_options: Command line arguments for training.
"""
if options['--docker-target-name'] == 'Empty':
if run_options['--docker-target-name'] == 'Empty':
docker_target_name = options['--docker-target-name']
docker_target_name = run_options['--docker-target-name']
run_id = options['--run-id']
num_runs = int(options['--num-runs'])
seed = int(options['--seed'])
load_model = options['--load']
train_model = options['--train']
save_freq = int(options['--save-freq'])
env_path = options['<env>']
keep_checkpoints = int(options['--keep-checkpoints'])
worker_id = int(options['--worker-id'])
curriculum_file = str(options['--curriculum'])
run_id = run_options['--run-id']
load_model = run_options['--load']
train_model = run_options['--train']
save_freq = int(run_options['--save-freq'])
keep_checkpoints = int(run_options['--keep-checkpoints'])
worker_id = int(run_options['--worker-id'])
curriculum_file = str(run_options['--curriculum'])
lesson = int(options['--lesson'])
fast_simulation = not bool(options['--slow'])
no_graphics = options['--no-graphics']
lesson = int(run_options['--lesson'])
fast_simulation = not bool(run_options['--slow'])
no_graphics = run_options['--no-graphics']
TRAINER_CONFIG_PATH = os.path.abspath(os.path.join(base_path, "trainer_config.yaml"))
if env_path is None and num_runs > 1:
raise TrainerError("It is not possible to launch more than one concurrent training session "
"when training from the editor")
trainer_config_path = os.path.abspath(os.path.join(base_path, "trainer_config.yaml"))
tc = TrainerController(env_path, run_id + "-" + str(sub_id), save_freq, curriculum_file, fast_simulation,
load_model, train_model, worker_id + sub_id, keep_checkpoints, lesson, use_seed,
docker_target_name, TRAINER_CONFIG_PATH, no_graphics)
# Create controller and begin training.
tc = TrainerController(env_path, run_id + "-" + str(sub_id), save_freq, curriculum_file,
fast_simulation, load_model, train_model, worker_id + sub_id,
keep_checkpoints, lesson, run_seed, docker_target_name,
trainer_config_path, no_graphics)
if __name__ == '__main__':
print('''

--keep-checkpoints=<n> How many model checkpoints to keep [default: 5].
--lesson=<n> Start learning from this lesson [default: 0].
--load Whether to load the model or randomly initialize [default: False].
--run-id=<path> The sub-directory name for model and summary statistics [default: ppo].
--run-id=<path> The directory name for model and summary statistics [default: ppo].
--worker-id=<n> Number to add to communication port (5005). Used for multi-environment [default: 0].
--docker-target-name=<dt> Docker Volume to store curriculum, executable and model files [default: Empty].
--no-graphics Whether to run the Unity simulator in no-graphics mode [default: False].
--worker-id=<n> Number to add to communication port (5005) [default: 0].
--docker-target-name=<dt> Docker volume to store training-specific files [default: Empty].
--no-graphics Whether to run the environment in no-graphics mode [default: False].
'''
options = docopt(_USAGE)

env_path = options['<env>']
if env_path is None and num_runs > 1:
raise TrainerError("It is not possible to launch more than one concurrent training session "
"when training from the editor")
use_seed = np.random.randint(0, 9999)
else:
use_seed = seed
p = multiprocessing.Process(target=run_training, args=(i, use_seed, options))
seed = np.random.randint(0, 9999)
p = multiprocessing.Process(target=run_training, args=(i, seed, options))
p.start()
p.start()
正在加载...
取消
保存