|
|
|
|
|
|
from mlagents.trainers.cli_utils import DetectDefault |
|
|
|
from mlagents_envs.exception import UnityEnvironmentException |
|
|
|
from mlagents.trainers.stats import StatsReporter |
|
|
|
from mlagents.trainers.settings import UniformSettings |
|
|
|
|
|
|
|
|
|
|
|
def basic_options(extra_args=None): |
|
|
|
|
|
|
|
|
|
|
MOCK_SAMPLER_CURRICULUM_YAML = """ |
|
|
|
parameter_randomization: |
|
|
|
sampler1: foo |
|
|
|
sampler1: |
|
|
|
uniform: |
|
|
|
min_value: 0.2 |
|
|
|
|
|
|
|
curriculum: |
|
|
|
behavior1: |
|
|
|
|
|
|
@patch("mlagents.trainers.learn.write_run_options") |
|
|
|
@patch("mlagents.trainers.learn.handle_existing_directories") |
|
|
|
@patch("mlagents.trainers.learn.TrainerFactory") |
|
|
|
@patch("mlagents.trainers.learn.SamplerManager") |
|
|
|
@patch("mlagents.trainers.learn.SubprocessEnvManager") |
|
|
|
@patch("mlagents.trainers.learn.create_environment_factory") |
|
|
|
@patch("mlagents.trainers.settings.load_config") |
|
|
|
|
|
|
subproc_env_mock, |
|
|
|
sampler_manager_mock, |
|
|
|
trainer_factory_mock, |
|
|
|
handle_dir_mock, |
|
|
|
write_run_options_mock, |
|
|
|
|
|
|
options = basic_options() |
|
|
|
learn.run_training(0, options) |
|
|
|
mock_init.assert_called_once_with( |
|
|
|
trainer_factory_mock.return_value, |
|
|
|
"results/ppo", |
|
|
|
"ppo", |
|
|
|
None, |
|
|
|
True, |
|
|
|
0, |
|
|
|
sampler_manager_mock.return_value, |
|
|
|
None, |
|
|
|
trainer_factory_mock.return_value, "results/ppo", "ppo", None, True, 0 |
|
|
|
) |
|
|
|
handle_dir_mock.assert_called_once_with("results/ppo", False, False, None) |
|
|
|
write_timing_tree_mock.assert_called_once_with("results/ppo/run_logs") |
|
|
|
|
|
|
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_SAMPLER_CURRICULUM_YAML) |
|
|
|
def test_sampler_configs(mock_file): |
|
|
|
opt = parse_command_line(["mytrainerpath"]) |
|
|
|
assert opt.parameter_randomization == {"sampler1": "foo"} |
|
|
|
assert isinstance(opt.parameter_randomization["sampler1"], UniformSettings) |
|
|
|
assert len(opt.curriculum.keys()) == 2 |
|
|
|
|
|
|
|
|
|
|
|