浏览代码

fix tests

/sampler-refactor-copy
Andrew Cohen 5 年前
当前提交
f76780f1
共有 3 个文件被更改,包括 6 次插入20 次删除
  1. 18
      ml-agents/mlagents/trainers/tests/test_learn.py
  2. 3
      ml-agents/mlagents/trainers/tests/test_simple_rl.py
  3. 5
      ml-agents/mlagents/trainers/tests/test_trainer_controller.py

18
ml-agents/mlagents/trainers/tests/test_learn.py


from mlagents.trainers.cli_utils import DetectDefault
from mlagents_envs.exception import UnityEnvironmentException
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.settings import UniformSettings
def basic_options(extra_args=None):

MOCK_SAMPLER_CURRICULUM_YAML = """
parameter_randomization:
sampler1: foo
sampler1:
uniform:
min_value: 0.2
curriculum:
behavior1:

@patch("mlagents.trainers.learn.write_run_options")
@patch("mlagents.trainers.learn.handle_existing_directories")
@patch("mlagents.trainers.learn.TrainerFactory")
@patch("mlagents.trainers.learn.SamplerManager")
@patch("mlagents.trainers.learn.SubprocessEnvManager")
@patch("mlagents.trainers.learn.create_environment_factory")
@patch("mlagents.trainers.settings.load_config")

subproc_env_mock,
sampler_manager_mock,
trainer_factory_mock,
handle_dir_mock,
write_run_options_mock,

options = basic_options()
learn.run_training(0, options)
mock_init.assert_called_once_with(
trainer_factory_mock.return_value,
"results/ppo",
"ppo",
None,
True,
0,
sampler_manager_mock.return_value,
None,
trainer_factory_mock.return_value, "results/ppo", "ppo", None, True, 0
)
handle_dir_mock.assert_called_once_with("results/ppo", False, False, None)
write_timing_tree_mock.assert_called_once_with("results/ppo/run_logs")

@patch("builtins.open", new_callable=mock_open, read_data=MOCK_SAMPLER_CURRICULUM_YAML)
def test_sampler_configs(mock_file):
opt = parse_command_line(["mytrainerpath"])
assert opt.parameter_randomization == {"sampler1": "foo"}
assert isinstance(opt.parameter_randomization["sampler1"], UniformSettings)
assert len(opt.curriculum.keys()) == 2

3
ml-agents/mlagents/trainers/tests/test_simple_rl.py


from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.simple_env_manager import SimpleEnvManager
from mlagents.trainers.sampler_class import SamplerManager
from mlagents.trainers.demo_loader import write_demo
from mlagents.trainers.stats import StatsReporter, StatsWriter, StatsSummary
from mlagents.trainers.settings import (

meta_curriculum=meta_curriculum,
train=True,
training_seed=seed,
sampler_manager=SamplerManager(None),
resampling_interval=None,
)
# Begin training

5
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


from mlagents.tf_utils import tf
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.sampler_class import SamplerManager
@pytest.fixture

meta_curriculum=None,
train=True,
training_seed=99,
sampler_manager=SamplerManager({}),
resampling_interval=None,
)

meta_curriculum=None,
train=True,
training_seed=seed,
sampler_manager=SamplerManager({}),
resampling_interval=None,
)
numpy_random_seed.assert_called_with(seed)
tensorflow_set_seed.assert_called_with(seed)

正在加载...
取消
保存