浏览代码

Enable generalization training (#2232)

* Add Sampler and SamplerManager
* Enable resampling of reset parameters during training
* Documentation for Sampler and example YAML configuration file
/develop-generalizationTraining-TrainerController
Ervin T 5 年前
当前提交
a46f3faa
共有 14 个文件被更改,包括 1857 次插入43 次删除
  1. 4
      docs/Training-ML-Agents.md
  2. 8
      ml-agents-envs/mlagents/envs/exception.py
  3. 2
      ml-agents/mlagents/trainers/exception.py
  4. 67
      ml-agents/mlagents/trainers/learn.py
  5. 3
      ml-agents/mlagents/trainers/tests/test_environments/test_simple.py
  6. 13
      ml-agents/mlagents/trainers/tests/test_learn.py
  7. 19
      ml-agents/mlagents/trainers/tests/test_trainer_controller.py
  8. 81
      ml-agents/mlagents/trainers/trainer_controller.py
  9. 16
      config/generalize_test.yaml
  10. 124
      docs/Training-Generalization-Learning.md
  11. 850
      docs/images/3dball_big.png
  12. 482
      docs/images/3dball_small.png
  13. 134
      ml-agents-envs/mlagents/envs/sampler_class.py
  14. 97
      ml-agents-envs/mlagents/envs/tests/test_sampler_class.py

4
docs/Training-ML-Agents.md


* `--curriculum=<file>` – Specify a curriculum JSON file for defining the
lessons for curriculum training. See [Curriculum
Training](Training-Curriculum-Learning.md) for more information.
* `--sampler=<file>` - Specify a sampler YAML file for defining the
sampler for generalization training. See [Generalization
Training](Training-Generalization-Learning.md) for more information.
* `--keep-checkpoints=<n>` – Specify the maximum number of model checkpoints to
keep. Checkpoints are saved after the number of steps specified by the
`save-freq` option. Once the maximum number of checkpoints has been reached,

* [Training with PPO](Training-PPO.md)
* [Using Recurrent Neural Networks](Feature-Memory.md)
* [Training with Curriculum Learning](Training-Curriculum-Learning.md)
* [Training with Generalization](Training-Generalization-Learning.md)
* [Training with Imitation Learning](Training-Imitation-Learning.md)
You can also compare the

8
ml-agents-envs/mlagents/envs/exception.py


pass
class SamplerException(UnityException):
"""
Related to errors with the sampler actions.
"""
pass
class UnityTimeOutException(UnityException):
"""
Related to errors with communication timeouts.

2
ml-agents/mlagents/trainers/exception.py


"""
Any error related to the configuration of a metacurriculum.
"""
pass

67
ml-agents/mlagents/trainers/learn.py


from mlagents.trainers.exception import TrainerError
from mlagents.trainers import MetaCurriculumError, MetaCurriculum
from mlagents.envs import UnityEnvironment
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.envs.sampler_class import SamplerManager
from mlagents.envs.exception import UnityEnvironmentException, SamplerException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.subprocess_env_manager import SubprocessEnvManager

fast_simulation = not bool(run_options["--slow"])
no_graphics = run_options["--no-graphics"]
trainer_config_path = run_options["<trainer-config-path>"]
sampler_file_path = (
run_options["--sampler"] if run_options["--sampler"] != "None" else None
)
# Recognize and use docker volume if one is passed as an argument
if not docker_target_name:
model_path = "./models/{run_id}-{sub_id}".format(run_id=run_id, sub_id=sub_id)

)
env = SubprocessEnvManager(env_factory, num_envs)
maybe_meta_curriculum = try_create_meta_curriculum(curriculum_folder, env)
sampler_manager, resampling_interval = create_sampler_manager(
sampler_file_path, env.reset_parameters
)
# Create controller and begin training.
tc = TrainerController(

lesson,
run_seed,
fast_simulation,
sampler_manager,
resampling_interval,
)
# Signal that environment has been launched.

tc.start_learning(env, trainer_config)
def create_sampler_manager(sampler_file_path, env_reset_params):
sampler_config = None
resample_interval = None
if sampler_file_path is not None:
sampler_config = load_config(sampler_file_path)
if ("resampling-interval") in sampler_config:
# Filter arguments that do not exist in the environment
resample_interval = sampler_config.pop("resampling-interval")
if (resample_interval <= 0) or (not isinstance(resample_interval, int)):
raise SamplerException(
"Specified resampling-interval is not valid. Please provide"
" a positive integer value for resampling-interval"
)
else:
raise SamplerException(
"Resampling interval was not specified in the sampler file."
" Please specify it with the 'resampling-interval' key in the sampler config file."
)
sampler_manager = SamplerManager(sampler_config)
return sampler_manager, resample_interval
def try_create_meta_curriculum(

mlagents-learn --help
Options:
--env=<file> Name of the Unity executable [default: None].
--curriculum=<directory> Curriculum json directory for environment [default: None].
--keep-checkpoints=<n> How many model checkpoints to keep [default: 5].
--lesson=<n> Start learning from this lesson [default: 0].
--load Whether to load the model or randomly initialize [default: False].
--run-id=<path> The directory name for model and summary statistics [default: ppo].
--num-runs=<n> Number of concurrent training sessions [default: 1].
--save-freq=<n> Frequency at which to save model [default: 50000].
--seed=<n> Random seed used for training [default: -1].
--slow Whether to run the game at training speed [default: False].
--train Whether to train model, or only run inference [default: False].
--base-port=<n> Base port for environment communication [default: 5005].
--num-envs=<n> Number of parallel environments to use for training [default: 1]
--docker-target-name=<dt> Docker volume to store training-specific files [default: None].
--no-graphics Whether to run the environment in no-graphics mode [default: False].
--debug Whether to run ML-Agents in debug mode with detailed logging [default: False].
--env=<file> Name of the Unity executable [default: None].
--curriculum=<directory> Curriculum json directory for environment [default: None].
--sampler=<file> Reset parameter yaml file for environment [default: None].
--keep-checkpoints=<n> How many model checkpoints to keep [default: 5].
--lesson=<n> Start learning from this lesson [default: 0].
--load Whether to load the model or randomly initialize [default: False].
--run-id=<path> The directory name for model and summary statistics [default: ppo].
--num-runs=<n> Number of concurrent training sessions [default: 1].
--save-freq=<n> Frequency at which to save model [default: 50000].
--seed=<n> Random seed used for training [default: -1].
--slow Whether to run the game at training speed [default: False].
--train Whether to train model, or only run inference [default: False].
--base-port=<n> Base port for environment communication [default: 5005].
--num-envs=<n> Number of parallel environments to use for training [default: 1]
--docker-target-name=<dt> Docker volume to store training-specific files [default: None].
--no-graphics Whether to run the environment in no-graphics mode [default: False].
--debug Whether to run ML-Agents in debug mode with detailed logging [default: False].
"""
options = docopt(_USAGE)

3
ml-agents/mlagents/trainers/tests/test_environments/test_simple.py


from mlagents.envs import BrainInfo, AllBrainInfo, BrainParameters
from mlagents.envs.communicator_objects import AgentInfoProto
from mlagents.envs.simple_env_manager import SimpleEnvManager
from mlagents.envs.sampler_class import SamplerManager
BRAIN_NAME = __name__

lesson=None,
training_seed=1337,
fast_simulation=True,
sampler_manager=SamplerManager(None),
resampling_interval=None,
)
# Begin training

13
ml-agents/mlagents/trainers/tests/test_learn.py


"--no-graphics": False,
"<trainer-config-path>": "basic_path",
"--debug": False,
"--sampler": None,
@patch("mlagents.trainers.learn.SamplerManager")
def test_run_training(load_config, create_environment_factory, subproc_env_mock):
def test_run_training(
load_config, create_environment_factory, subproc_env_mock, sampler_manager_mock
):
mock_env = MagicMock()
mock_env.external_brain_names = []
mock_env.academy_name = "TestAcademyName"

0,
0,
True,
sampler_manager_mock.return_value,
None,
@patch("mlagents.trainers.learn.SamplerManager")
def test_docker_target_path(load_config, create_environment_factory, subproc_env_mock):
def test_docker_target_path(
load_config, create_environment_factory, subproc_env_mock, sampler_manager_mock
):
mock_env = MagicMock()
mock_env.external_brain_names = []
mock_env.academy_name = "TestAcademyName"

19
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


from mlagents.trainers.bc.online_trainer import OnlineBCTrainer
from mlagents.envs.subprocess_env_manager import StepInfo
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.envs.sampler_class import SamplerManager
@pytest.fixture

lesson=None,
training_seed=99,
fast_simulation=True,
sampler_manager=SamplerManager(None),
resampling_interval=None,
)

seed = 27
TrainerController("", "", "1", 1, None, True, False, False, None, seed, True)
TrainerController(
"",
"",
"1",
1,
None,
True,
False,
False,
None,
seed,
True,
SamplerManager(None),
None,
)
numpy_random_seed.assert_called_with(seed)
tensorflow_set_seed.assert_called_with(seed)

81
ml-agents/mlagents/trainers/trainer_controller.py


from mlagents.envs.env_manager import StepInfo
from mlagents.envs.env_manager import EnvManager
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.envs.sampler_class import SamplerManager
from mlagents.envs.timers import hierarchical_timer, get_timer_tree, timed
from mlagents.trainers import Trainer, TrainerMetrics
from mlagents.trainers.ppo.trainer import PPOTrainer

from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.subprocess_env_manager import SubprocessEnvManager
class TrainerController(object):

lesson: Optional[int],
training_seed: int,
fast_simulation: bool,
sampler_manager: SamplerManager,
resampling_interval: Optional[int],
):
"""
:param model_path: Path to save the model.

:param keep_checkpoints: How many model checkpoints to keep.
:param lesson: Start learning from this lesson.
:param training_seed: Seed to use for Numpy and Tensorflow random number generation.
:param sampler_manager: SamplerManager object handles samplers for resampling the reset parameters.
:param resampling_interval: Specifies number of simulation steps after which reset parameters are resampled.
"""
self.model_path = model_path

self.fast_simulation = fast_simulation
np.random.seed(self.seed)
tf.set_random_seed(self.seed)
self.sampler_manager = sampler_manager
self.resampling_interval = resampling_interval
def _get_measure_vals(self):
brain_names_to_measure_vals = {}

run_id=self.run_id,
)
elif trainer_parameters_dict[brain_name]["trainer"] == "ppo":
self.trainers[brain_name] = PPOTrainer(
brain=external_brains[brain_name],
reward_buff_cap=self.meta_curriculum.brains_to_curriculums[
# Find lesson length based on the form of learning
if self.meta_curriculum:
lesson_length = self.meta_curriculum.brains_to_curriculums[
if self.meta_curriculum
else 1,
else:
lesson_length = 1
self.trainers[brain_name] = PPOTrainer(
brain=external_brains[brain_name],
reward_buff_cap=lesson_length,
trainer_parameters=trainer_parameters_dict[brain_name],
training=self.train_model,
load=self.load_model,

A Data structure corresponding to the initial reset state of the
environment.
"""
if self.meta_curriculum is not None:
return env.reset(
train_mode=self.fast_simulation,
config=self.meta_curriculum.get_config(),
)
else:
return env.reset(train_mode=self.fast_simulation)
sampled_reset_param = self.sampler_manager.sample_all()
new_meta_curriculum_config = (
self.meta_curriculum.get_config() if self.meta_curriculum else {}
)
sampled_reset_param.update(new_meta_curriculum_config)
return env.reset(train_mode=self.fast_simulation, config=sampled_reset_param)
def _should_save_model(self, global_step: int) -> bool:
return (

n_steps = self.advance(env_manager)
for i in range(n_steps):
global_step += 1
self.reset_env_if_ready(env_manager, global_step)
if self._should_save_model(global_step):
# Save Tensorflow model
self._save_model()

self._export_graph()
self._write_timing_tree()
@timed
def advance(self, env: EnvManager) -> int:
def end_trainer_episodes(
self, env: BaseUnityEnvironment, lessons_incremented: Dict[str, bool]
) -> None:
self._reset_env(env)
# Reward buffers reset takes place only for curriculum learning
# else no reset.
for brain_name, trainer in self.trainers.items():
trainer.end_episode()
for brain_name, changed in lessons_incremented.items():
if changed:
self.trainers[brain_name].reward_buffer.clear()
def reset_env_if_ready(self, env: BaseUnityEnvironment, steps: int) -> None:
if self.meta_curriculum:
# Get the sizes of the reward buffers.
reward_buff_sizes = {

lessons_incremented = self.meta_curriculum.increment_lessons(
self._get_measure_vals(), reward_buff_sizes=reward_buff_sizes
)
# If any lessons were incremented or the environment is
# ready to be reset
if any(lessons_incremented.values()):
self._reset_env(env)
for brain_name, trainer in self.trainers.items():
trainer.end_episode()
for brain_name, changed in lessons_incremented.items():
if changed:
self.trainers[brain_name].reward_buffer.clear()
else:
lessons_incremented = {}
# If any lessons were incremented or the environment is
# ready to be reset
meta_curriculum_reset = any(lessons_incremented.values())
# Check if we are performing generalization training and we have finished the
# specified number of steps for the lesson
generalization_reset = (
not self.sampler_manager.is_empty()
and (steps != 0)
and (self.resampling_interval)
and (steps % self.resampling_interval == 0)
)
if meta_curriculum_reset or generalization_reset:
self.end_trainer_episodes(env, lessons_incremented)
@timed
def advance(self, env: SubprocessEnvManager) -> int:
with hierarchical_timer("env_step"):
time_start_step = time()
new_step_infos = env.step()

16
config/generalize_test.yaml


resampling-interval: 5000
mass:
sampler-type: "uniform"
min_value: 0.5
max_value: 10
gravity:
sampler-type: "uniform"
min_value: 7
max_value: 12
scale:
sampler-type: "uniform"
min_value: 0.75
max_value: 3

124
docs/Training-Generalization-Learning.md


# Training Generalized Reinforcement Learning Agents
Reinforcement learning has a rather unique setup as opposed to supervised and
unsupervised learning. Agents here are trained and tested on the same exact
environment, which is analogous to a model being trained and tested on an
identical dataset in supervised learning! This setting results in overfitting;
the inability of the agent to generalize to slight tweaks or variations in the
environment. This is problematic in instances when environments are randomly
instantiated with varying properties. To make agents robust, one approach is to
train an agent over multiple variations of the environment. The agent is
trained in this approach with the intent that it learns to adapt its performance
to future unseen variations of the environment.
Ball scale of 0.5 | Ball scale of 4
:-------------------------:|:-------------------------:
![](images/3dball_small.png) | ![](images/3dball_big.png)
_Variations of the 3D Ball environment._
To vary environments, we first decide what parameters to vary in an
environment. These parameters are known as `Reset Parameters`. In the 3D ball
environment example displayed in the figure above, the reset parameters are `gravity`, `ball_mass` and `ball_scale`.
## How-to
For generalization training, we need to provide a way to modify the environment
by supplying a set of reset parameters, and vary them over time. This provision
can be done either deterministically or randomly.
This is done by assigning each reset parameter a sampler, which samples a reset
parameter value (such as a uniform sampler). If a sampler isn't provided for a
reset parameter, the parameter maintains the default value throughout the
training, remaining unchanged. The samplers for all the reset parameters are
handled by a **Sampler Manager**, which also handles the generation of new
values for the reset parameters when needed.
To setup the Sampler Manager, we setup a YAML file that specifies how we wish to
generate new samples. In this file, we specify the samplers and the
`resampling-duration` (number of simulation steps after which reset parameters are
resampled). Below is an example of a sampler file for the 3D ball environment.
```yaml
episode-length: 5000
mass:
sampler-type: "uniform"
min_value: 0.5
max_value: 10
gravity:
sampler-type: "multirange_uniform"
intervals: [[7, 10], [15, 20]]
scale:
sampler-type: "uniform"
min_value: 0.75
max_value: 3
```
* `resampling-duration` (int) - Specifies the number of steps for agent to
train under a particular environment configuration before resetting the
environment with a new sample of reset parameters.
* `parameter_name` - Name of the reset parameter. This should match the name
specified in the academy of the intended environment for which the agent is
being trained. If a parameter specified in the file doesn't exist in the
environment, then this specification will be ignored.
* `sampler-type` - Specify the sampler type to use for the reset parameter.
This is a string that should exist in the `Sampler Factory` (explained
below).
* `sub-arguments` - Specify the characteristic parameters for the sampler.
In the example sampler file above, this would correspond to the `intervals`
key under the `multirange_uniform` sampler for the gravity reset parameter.
The key name should match the name of the corresponding argument in the sampler definition. (Look at defining a new sampler method)
The sampler manager allocates a sampler for a reset parameter by using the *Sampler Factory*, which maintains a dictionary mapping of string keys to sampler objects. The available samplers to be used for reset parameter resampling is as available in the Sampler Factory.
The implementation of the samplers can be found at `ml-agents-envs/mlagents/envs/sampler_class.py`.
### Defining a new sampler method
Custom sampling techniques must inherit from the *Sampler* base class (included in the `sampler_class` file) and preserve the interface. Once the class for the required method is specified, it must be registered in the Sampler Factory.
This can be done by subscribing to the *register_sampler* method of the SamplerFactory. The command is as follows:
`SamplerFactory.register_sampler(*custom_sampler_string_key*, *custom_sampler_object*)`
Once the Sampler Factory reflects the new register, the custom sampler can be used for resampling reset parameter. For demonstration, lets say our sampler was implemented as below, and we register the `CustomSampler` class with the string `custom-sampler` in the Sampler Factory.
```python
class CustomSampler(Sampler):
def __init__(self, argA, argB, argC):
self.possible_vals = [argA, argB, argC]
def sample_all(self):
return np.random.choice(self.possible_vals)
```
Now we need to specify this sampler in the sampler file. Lets say we wish to use this sampler for the reset parameter *mass*; the sampler file would specify the same for mass as the following (any order of the subarguments is valid).
```yaml
mass:
sampler-type: "custom-sampler"
argB: 1
argA: 2
argC: 3
```
With the sampler file setup, we can proceed to train our agent as explained in the next section.
### Training with Generalization Learning
We first begin with setting up the sampler file. After the sampler file is defined and configured, we proceed by launching `mlagents-learn` and specify our configured sampler file with the `--sampler` flag. To demonstrate, if we wanted to train a 3D ball agent with generalization using the `config/generalization-test.yaml` sampling setup, we can run
```sh
mlagents-learn config/trainer_config.yaml --sampler=config/generalize_test.yaml --run-id=3D-Ball-generalization --train
```
We can observe progress and metrics via Tensorboard.

850
docs/images/3dball_big.png

之前 之后
宽度: 1372  |  高度: 812  |  大小: 196 KiB

482
docs/images/3dball_small.png

之前 之后
宽度: 1372  |  高度: 816  |  大小: 139 KiB

134
ml-agents-envs/mlagents/envs/sampler_class.py


import numpy as np
from typing import *
from functools import *
from collections import OrderedDict
from abc import ABC, abstractmethod
from .exception import SamplerException
class Sampler(ABC):
@abstractmethod
def sample_parameter(self) -> float:
pass
class UniformSampler(Sampler):
"""
Uniformly draws a single sample in the range [min_value, max_value).
"""
def __init__(
self, min_value: Union[int, float], max_value: Union[int, float], **kwargs
) -> None:
self.min_value = min_value
self.max_value = max_value
def sample_parameter(self) -> float:
return np.random.uniform(self.min_value, self.max_value)
class MultiRangeUniformSampler(Sampler):
"""
Draws a single sample uniformly from the intervals provided. The sampler
first picks an interval based on a weighted selection, with the weights
assigned to an interval based on its range. After picking the range,
it proceeds to pick a value uniformly in that range.
"""
def __init__(self, intervals: List[List[Union[int, float]]], **kwargs) -> None:
self.intervals = intervals
# Measure the length of the intervals
interval_lengths = [abs(x[1] - x[0]) for x in self.intervals]
cum_interval_length = sum(interval_lengths)
# Assign weights to an interval proportionate to the interval size
self.interval_weights = [x / cum_interval_length for x in interval_lengths]
def sample_parameter(self) -> float:
cur_min, cur_max = self.intervals[
np.random.choice(len(self.intervals), p=self.interval_weights)
]
return np.random.uniform(cur_min, cur_max)
class GaussianSampler(Sampler):
"""
Draw a single sample value from a normal (gaussian) distribution.
This sampler is characterized by the mean and the standard deviation.
"""
def __init__(
self, mean: Union[float, int], st_dev: Union[float, int], **kwargs
) -> None:
self.mean = mean
self.st_dev = st_dev
def sample_parameter(self) -> float:
return np.random.normal(self.mean, self.st_dev)
class SamplerFactory:
"""
Maintain a directory of all samplers available.
Add new samplers using the register_sampler method.
"""
NAME_TO_CLASS = {
"uniform": UniformSampler,
"gaussian": GaussianSampler,
"multirange_uniform": MultiRangeUniformSampler,
}
@staticmethod
def register_sampler(name: str, sampler_cls: Type[Sampler]) -> None:
SamplerFactory.NAME_TO_CLASS[name] = sampler_cls
@staticmethod
def init_sampler_class(name: str, params: Dict[str, Any]):
if name not in SamplerFactory.NAME_TO_CLASS:
raise SamplerException(
name + " sampler is not registered in the SamplerFactory."
" Use the register_sample method to register the string"
" associated to your sampler in the SamplerFactory."
)
sampler_cls = SamplerFactory.NAME_TO_CLASS[name]
try:
return sampler_cls(**params)
except TypeError:
raise SamplerException(
"The sampler class associated to the " + name + " key in the factory "
"was not provided the required arguments. Please ensure that the sampler "
"config file consists of the appropriate keys for this sampler class."
)
class SamplerManager:
def __init__(self, reset_param_dict: Dict[str, Any]) -> None:
self.reset_param_dict = reset_param_dict if reset_param_dict else {}
assert isinstance(self.reset_param_dict, dict)
self.samplers: Dict[str, Sampler] = {}
for param_name, cur_param_dict in self.reset_param_dict.items():
if "sampler-type" not in cur_param_dict:
raise SamplerException(
"'sampler_type' argument hasn't been supplied for the {0} parameter".format(
param_name
)
)
sampler_name = cur_param_dict.pop("sampler-type")
param_sampler = SamplerFactory.init_sampler_class(
sampler_name, cur_param_dict
)
self.samplers[param_name] = param_sampler
def is_empty(self) -> bool:
"""
Check for if sampler_manager is empty.
"""
return not bool(self.samplers)
def sample_all(self) -> Dict[str, float]:
res = {}
for param_name, param_sampler in list(self.samplers.items()):
res[param_name] = param_sampler.sample_parameter()
return res

97
ml-agents-envs/mlagents/envs/tests/test_sampler_class.py


from math import isclose
import pytest
from mlagents.envs.sampler_class import SamplerManager
from mlagents.envs.sampler_class import (
UniformSampler,
MultiRangeUniformSampler,
GaussianSampler,
)
from mlagents.envs.exception import UnityException
def sampler_config_1():
return {
"mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10},
"gravity": {
"sampler-type": "multirange_uniform",
"intervals": [[8, 11], [15, 20]],
},
}
def check_value_in_intervals(val, intervals):
check_in_bounds = [a <= val <= b for a, b in intervals]
return any(check_in_bounds)
def test_sampler_config_1():
config = sampler_config_1()
sampler = SamplerManager(config)
assert sampler.is_empty() is False
assert isinstance(sampler.samplers["mass"], UniformSampler)
assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler)
cur_sample = sampler.sample_all()
# Check uniform sampler for mass
assert sampler.samplers["mass"].min_value == config["mass"]["min_value"]
assert sampler.samplers["mass"].max_value == config["mass"]["max_value"]
assert config["mass"]["min_value"] <= cur_sample["mass"]
assert config["mass"]["max_value"] >= cur_sample["mass"]
# Check multirange_uniform sampler for gravity
assert sampler.samplers["gravity"].intervals == config["gravity"]["intervals"]
assert check_value_in_intervals(
cur_sample["gravity"], sampler.samplers["gravity"].intervals
)
def sampler_config_2():
return {"angle": {"sampler-type": "gaussian", "mean": 0, "st_dev": 1}}
def test_sampler_config_2():
config = sampler_config_2()
sampler = SamplerManager(config)
assert sampler.is_empty() is False
assert isinstance(sampler.samplers["angle"], GaussianSampler)
# Check angle gaussian sampler
assert sampler.samplers["angle"].mean == config["angle"]["mean"]
assert sampler.samplers["angle"].st_dev == config["angle"]["st_dev"]
def test_empty_samplers():
empty_sampler = SamplerManager({})
assert empty_sampler.is_empty()
empty_cur_sample = empty_sampler.sample_all()
assert empty_cur_sample == {}
none_sampler = SamplerManager(None)
assert none_sampler.is_empty()
none_cur_sample = none_sampler.sample_all()
assert none_cur_sample == {}
def incorrect_uniform_sampler():
# Do not specify required arguments to uniform sampler
return {"mass": {"sampler-type": "uniform", "min-value": 10}}
def incorrect_sampler_config():
# Do not specify 'sampler-type' key
return {"mass": {"min-value": 2, "max-value": 30}}
def test_incorrect_uniform_sampler():
config = incorrect_uniform_sampler()
with pytest.raises(UnityException):
SamplerManager(config)
def test_incorrect_sampler():
config = incorrect_sampler_config()
with pytest.raises(UnityException):
SamplerManager(config)
正在加载...
取消
保存