您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
222 行
7.1 KiB
222 行
7.1 KiB
import pytest
|
|
import yaml
|
|
from unittest.mock import MagicMock, patch, mock_open
|
|
from mlagents.trainers import learn
|
|
from mlagents.trainers.trainer_controller import TrainerController
|
|
from mlagents.trainers.learn import parse_command_line
|
|
from mlagents.trainers.cli_utils import DetectDefault
|
|
from mlagents_envs.exception import UnityEnvironmentException
|
|
from mlagents.trainers.stats import StatsReporter
|
|
|
|
|
|
def basic_options(extra_args=None):
    """Parse a minimal command line made of ``basic_path`` plus optional extras.

    Args:
        extra_args: optional mapping; each ``key, value`` pair is rendered as
            the single token ``f"{key}={value}"`` and appended after
            ``basic_path``. ``None`` (or any falsy value) adds nothing.

    Returns:
        Whatever ``parse_command_line`` returns for the assembled token list.
    """
    cli_tokens = ["basic_path"]
    for flag, value in (extra_args or {}).items():
        cli_tokens.append(f"{flag}={value}")
    return parse_command_line(cli_tokens)
|
|
|
|
|
|
# Minimal trainer config: an empty ``behaviors`` mapping and nothing else.
# The leading indentation is uniform, so YAML treats it as a top-level block mapping.
MOCK_YAML = """
    behaviors:
        {}
    """
|
|
|
|
# Config that sets env/checkpoint/debug options from YAML. test_yaml_args
# expects these values to be used unless overridden on the command line.
# The nesting under env_settings / checkpoint_settings is significant.
MOCK_PARAMETER_YAML = """
    behaviors:
        {}
    env_settings:
        env_path: "./oldenvfile"
        num_envs: 4
        base_port: 4001
        seed: 9870
    checkpoint_settings:
        run_id: uselessrun
    debug: false
    """
|
|
|
|
# Config exercising the parameter_randomization and per-behavior curriculum
# sections. test_sampler_configs expects one sampler entry and two curriculum
# behaviors, so behavior1/behavior2 must be nested under ``curriculum``.
MOCK_SAMPLER_CURRICULUM_YAML = """
    parameter_randomization:
        sampler1: foo

    curriculum:
        behavior1:
            parameters:
                foo: [0.2, 0.5]
        behavior2:
            parameters:
                foo: [0.2, 0.5]
    """
|
|
|
|
|
|
@patch("mlagents.trainers.learn.write_timing_tree")
@patch("mlagents.trainers.learn.write_run_options")
@patch("mlagents.trainers.learn.handle_existing_directories")
@patch("mlagents.trainers.learn.TrainerFactory")
@patch("mlagents.trainers.learn.SamplerManager")
@patch("mlagents.trainers.learn.SubprocessEnvManager")
@patch("mlagents.trainers.learn.create_environment_factory")
@patch("mlagents.trainers.settings.load_config")
def test_run_training(
    load_config,
    create_environment_factory,
    subproc_env_mock,
    sampler_manager_mock,
    trainer_factory_mock,
    handle_dir_mock,
    write_run_options_mock,
    write_timing_tree_mock,
):
    """Check the collaborator calls made by ``learn.run_training``.

    ``@patch`` decorators are applied bottom-up. The first parameter receives
    the mock for the bottom-most decorator (``load_config``) and the last
    parameter receives the mock for the top-most one (``write_timing_tree``).
    ``TrainerController.__init__`` and ``start_learning`` are stubbed so no
    real training runs.
    """
    mock_env = MagicMock()
    mock_env.external_brain_names = []
    mock_env.academy_name = "TestAcademyName"
    create_environment_factory.return_value = mock_env
    load_config.return_value = yaml.safe_load(MOCK_YAML)

    mock_init = MagicMock(return_value=None)
    try:
        with patch.object(TrainerController, "__init__", mock_init):
            with patch.object(TrainerController, "start_learning", MagicMock()):
                options = basic_options()
                learn.run_training(0, options)
                mock_init.assert_called_once_with(
                    trainer_factory_mock.return_value,
                    "results/ppo",
                    "ppo",
                    None,
                    True,
                    0,
                    sampler_manager_mock.return_value,
                    None,
                )
                handle_dir_mock.assert_called_once_with(
                    "results/ppo", False, False, None
                )
                write_timing_tree_mock.assert_called_once_with("results/ppo/run_logs")
                write_run_options_mock.assert_called_once_with("results/ppo", options)
    finally:
        # learn.py adds StatsReporter writers as a side effect. Clear them even
        # when an assertion above fails, so later tests don't inherit them.
        StatsReporter.writers.clear()
|
|
|
|
|
|
def test_bad_env_path():
    """An invalid ``env_path`` must surface as ``UnityEnvironmentException``.

    Both building the factory and invoking it are inside the ``raises``
    block, so the exception may come from either step.
    """
    bad_env = dict(
        env_path="/foo/bar",
        no_graphics=True,
        seed=-1,
        start_port=8000,
        env_args=None,
        log_folder="results/log_folder",
    )
    with pytest.raises(UnityEnvironmentException):
        env_factory = learn.create_environment_factory(**bad_env)
        env_factory(worker_id=-1, side_channels=[])
|
|
|
|
|
|
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_YAML)
def test_commandline_args(mock_file):
    """Check ``parse_command_line`` defaults and that each CLI flag is applied.

    ``builtins.open`` is patched, so any file opened while parsing (presumably
    the trainer config at ``mytrainerpath``) yields MOCK_YAML, whose
    ``behaviors`` mapping is empty.
    """
    # DetectDefault.non_default_args is class-level state that other parses
    # add to. Reset it (as test_yaml_args does) so this test doesn't depend on
    # which tests ran before it.
    DetectDefault.non_default_args.clear()

    # Defaults when only the positional config path is given.
    opt = parse_command_line(["mytrainerpath"])
    assert opt.behaviors == {}
    assert opt.env_settings.env_path is None
    assert opt.parameter_randomization is None
    assert opt.checkpoint_settings.resume is False
    assert opt.checkpoint_settings.inference is False
    assert opt.checkpoint_settings.run_id == "ppo"
    assert opt.env_settings.seed == -1
    assert opt.env_settings.base_port == 5005
    assert opt.env_settings.num_envs == 1
    assert opt.engine_settings.no_graphics is False
    assert opt.debug is False
    assert opt.env_settings.env_args is None

    # Every supported flag given explicitly on the command line.
    full_args = [
        "mytrainerpath",
        "--env=./myenvfile",
        "--resume",
        "--inference",
        "--run-id=myawesomerun",
        "--seed=7890",
        "--train",
        "--base-port=4004",
        "--num-envs=2",
        "--no-graphics",
        "--debug",
    ]

    opt = parse_command_line(full_args)
    assert opt.behaviors == {}
    assert opt.env_settings.env_path == "./myenvfile"
    assert opt.parameter_randomization is None
    assert opt.checkpoint_settings.run_id == "myawesomerun"
    assert opt.env_settings.seed == 7890
    assert opt.env_settings.base_port == 4004
    assert opt.env_settings.num_envs == 2
    assert opt.engine_settings.no_graphics is True
    assert opt.debug is True
    assert opt.checkpoint_settings.inference is True
    assert opt.checkpoint_settings.resume is True
|
|
|
|
|
|
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_PARAMETER_YAML)
def test_yaml_args(mock_file):
    """Settings come from the YAML config unless a CLI flag overrides them.

    ``builtins.open`` is patched to serve MOCK_PARAMETER_YAML. The class-level
    ``DetectDefault.non_default_args`` is reset first (presumably so flags
    recorded by earlier parses aren't treated as explicit overrides here).
    """
    DetectDefault.non_default_args.clear()

    # Phase 1: no flags -> every value should match MOCK_PARAMETER_YAML.
    yaml_only = parse_command_line(["mytrainerpath"])
    env = yaml_only.env_settings
    assert yaml_only.behaviors == {}
    assert yaml_only.parameter_randomization is None
    assert yaml_only.checkpoint_settings.run_id == "uselessrun"
    assert env.env_path == "./oldenvfile"
    assert (env.seed, env.base_port, env.num_envs) == (9870, 4001, 4)
    assert env.env_args is None
    assert yaml_only.engine_settings.no_graphics is False
    assert yaml_only.debug is False

    # Phase 2: explicit CLI flags take precedence over the YAML values.
    overrides = [
        "--env=./myenvfile",
        "--resume",
        "--inference",
        "--run-id=myawesomerun",
        "--seed=7890",
        "--train",
        "--base-port=4004",
        "--num-envs=2",
        "--no-graphics",
        "--debug",
    ]
    with_cli = parse_command_line(["mytrainerpath", *overrides])
    env = with_cli.env_settings
    ckpt = with_cli.checkpoint_settings
    assert with_cli.behaviors == {}
    assert with_cli.parameter_randomization is None
    assert env.env_path == "./myenvfile"
    assert (env.seed, env.base_port, env.num_envs) == (7890, 4004, 2)
    assert ckpt.run_id == "myawesomerun"
    assert ckpt.inference is True
    assert ckpt.resume is True
    assert with_cli.engine_settings.no_graphics is True
    assert with_cli.debug is True
|
|
|
|
|
|
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_SAMPLER_CURRICULUM_YAML)
def test_sampler_configs(mock_file):
    """The parameter_randomization and curriculum sections are read from YAML."""
    parsed = parse_command_line(["mytrainerpath"])
    expected_sampler = {"sampler1": "foo"}
    assert parsed.parameter_randomization == expected_sampler
    # One curriculum entry per behavior declared in the YAML.
    curriculum_names = parsed.curriculum.keys()
    assert len(curriculum_names) == 2
|
|
|
|
|
|
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_YAML)
def test_env_args(mock_file):
    """Every token after ``--env-args`` is grouped into ``env_settings.env_args``."""
    passthrough = ["--foo=bar", "--blah", "baz", "100"]
    cli = ["mytrainerpath", "--env=./myenvfile", "--env-args"] + passthrough
    parsed = parse_command_line(cli)
    assert parsed.env_settings.env_args == passthrough
|