浏览代码

[bug-fix] Fix regression in --initialize-from feature (#4086)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
09c7787c
共有 2 个文件被更改,包括 17 次插入3 次删除
  1. 2
      ml-agents/mlagents/trainers/learn.py
  2. 18
      ml-agents/mlagents/trainers/tests/test_learn.py

2
ml-agents/mlagents/trainers/learn.py


base_path = "results"
write_path = os.path.join(base_path, checkpoint_settings.run_id)
maybe_init_path = (
os.path.join(base_path, checkpoint_settings.run_id)
os.path.join(base_path, checkpoint_settings.initialize_from)
if checkpoint_settings.initialize_from
else None
)

18
ml-agents/mlagents/trainers/tests/test_learn.py


{}
"""
MOCK_INITIALIZE_YAML = """
behaviors:
{}
checkpoint_settings:
initialize_from: notuselessrun
"""
MOCK_PARAMETER_YAML = """
behaviors:
{}

seed: 9870
checkpoint_settings:
run_id: uselessrun
initialize_from: notuselessrun
debug: false
"""

mock_env.external_brain_names = []
mock_env.academy_name = "TestAcademyName"
create_environment_factory.return_value = mock_env
load_config.return_value = yaml.safe_load(MOCK_YAML)
load_config.return_value = yaml.safe_load(MOCK_INITIALIZE_YAML)
mock_init = MagicMock(return_value=None)
with patch.object(TrainerController, "__init__", mock_init):

sampler_manager_mock.return_value,
None,
)
handle_dir_mock.assert_called_once_with("results/ppo", False, False, None)
handle_dir_mock.assert_called_once_with(
"results/ppo", False, False, "results/notuselessrun"
)
write_timing_tree_mock.assert_called_once_with("results/ppo/run_logs")
write_run_options_mock.assert_called_once_with("results/ppo", options)
StatsReporter.writers.clear() # make sure there aren't any writers as added by learn.py

assert opt.checkpoint_settings.resume is False
assert opt.checkpoint_settings.inference is False
assert opt.checkpoint_settings.run_id == "ppo"
assert opt.checkpoint_settings.initialize_from is None
assert opt.env_settings.seed == -1
assert opt.env_settings.base_port == 5005
assert opt.env_settings.num_envs == 1

"--seed=7890",
"--train",
"--base-port=4004",
"--initialize-from=testdir",
"--num-envs=2",
"--no-graphics",
"--debug",

assert opt.env_settings.env_path == "./myenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "myawesomerun"
assert opt.checkpoint_settings.initialize_from == "testdir"
assert opt.env_settings.seed == 7890
assert opt.env_settings.base_port == 4004
assert opt.env_settings.num_envs == 2

assert opt.env_settings.env_path == "./oldenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "uselessrun"
assert opt.checkpoint_settings.initialize_from == "notuselessrun"
assert opt.env_settings.seed == 9870
assert opt.env_settings.base_port == 4001
assert opt.env_settings.num_envs == 4

正在加载...
取消
保存