base_path = "results"
write_path = os.path.join(base_path, checkpoint_settings.run_id)
maybe_init_path = (
os.path.join(base_path, checkpoint_settings.run_id)
os.path.join(base_path, checkpoint_settings.initialize_from)
if checkpoint_settings.initialize_from
else None
)
{}
"""
MOCK_INITIALIZE_YAML = """
behaviors:
checkpoint_settings:
initialize_from: notuselessrun
MOCK_PARAMETER_YAML = """
seed: 9870
run_id: uselessrun
debug: false
mock_env.external_brain_names = []
mock_env.academy_name = "TestAcademyName"
create_environment_factory.return_value = mock_env
load_config.return_value = yaml.safe_load(MOCK_YAML)
load_config.return_value = yaml.safe_load(MOCK_INITIALIZE_YAML)
mock_init = MagicMock(return_value=None)
with patch.object(TrainerController, "__init__", mock_init):
sampler_manager_mock.return_value,
None,
handle_dir_mock.assert_called_once_with("results/ppo", False, False, None)
handle_dir_mock.assert_called_once_with(
"results/ppo", False, False, "results/notuselessrun"
write_timing_tree_mock.assert_called_once_with("results/ppo/run_logs")
write_run_options_mock.assert_called_once_with("results/ppo", options)
StatsReporter.writers.clear() # make sure there aren't any writers as added by learn.py
assert opt.checkpoint_settings.resume is False
assert opt.checkpoint_settings.inference is False
assert opt.checkpoint_settings.run_id == "ppo"
assert opt.checkpoint_settings.initialize_from is None
assert opt.env_settings.seed == -1
assert opt.env_settings.base_port == 5005
assert opt.env_settings.num_envs == 1
"--seed=7890",
"--train",
"--base-port=4004",
"--initialize-from=testdir",
"--num-envs=2",
"--no-graphics",
"--debug",
assert opt.env_settings.env_path == "./myenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "myawesomerun"
assert opt.checkpoint_settings.initialize_from == "testdir"
assert opt.env_settings.seed == 7890
assert opt.env_settings.base_port == 4004
assert opt.env_settings.num_envs == 2
assert opt.env_settings.env_path == "./oldenvfile"
assert opt.checkpoint_settings.run_id == "uselessrun"
assert opt.checkpoint_settings.initialize_from == "notuselessrun"
assert opt.env_settings.seed == 9870
assert opt.env_settings.base_port == 4001
assert opt.env_settings.num_envs == 4