浏览代码

tests for settings

/sampler-refactor-copy
Andrew Cohen 4 年前
当前提交
fa5dae1a
共有 2 个文件被更改,包括 63 次插入96 次删除
  1. 63
      ml-agents/mlagents/trainers/tests/test_settings.py
  2. 96
      ml-agents/mlagents/trainers/tests/test_sampler_class.py

63
ml-agents/mlagents/trainers/tests/test_settings.py


RewardSignalType,
RewardSignalSettings,
CuriositySettings,
ParameterRandomizationSettings,
UniformSettings,
GaussianSettings,
MultiRangeUniformSettings,
TrainerType,
strict_to_cls,
)

RewardSignalSettings.structure(
"notadict", Dict[RewardSignalType, RewardSignalSettings]
)
def test_parameter_randomization_structure():
"""
Tests the ParameterRandomizationSettings structure method and all validators.
"""
parameter_randomization_dict = {
"mass": {"uniform": {"min_value": 1.0, "max_value": 2.0}},
"scale": {"gaussian": {"mean": 1.0, "st_dev": 2.0}},
"length": {"multirangeuniform": {"intervals": [[1.0, 2.0], [3.0, 4.0]]}},
}
parameter_randomization_distributions = ParameterRandomizationSettings.structure(
parameter_randomization_dict, Dict[str, ParameterRandomizationSettings]
)
assert isinstance(parameter_randomization_distributions["mass"], UniformSettings)
assert isinstance(parameter_randomization_distributions["scale"], GaussianSettings)
assert isinstance(
parameter_randomization_distributions["length"], MultiRangeUniformSettings
)
# Check invalid distribution type
invalid_distribution_dict = {"mass": {"beta": {"alpha": 1.0, "beta": 2.0}}}
with pytest.raises(ValueError):
ParameterRandomizationSettings.structure(
invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
)
# Check min less than max in uniform
invalid_distribution_dict = {
"mass": {"uniform": {"min_value": 2.0, "max_value": 1.0}}
}
with pytest.raises(TrainerConfigError):
ParameterRandomizationSettings.structure(
invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
)
# Check min less than max in multirange
invalid_distribution_dict = {
"mass": {"multirangeuniform": {"intervals": [[2.0, 1.0]]}}
}
with pytest.raises(TrainerConfigError):
ParameterRandomizationSettings.structure(
invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
)
# Check multirange has valid intervals
invalid_distribution_dict = {
"mass": {"multirangeuniform": {"intervals": [[1.0, 2.0], [3.0]]}}
}
with pytest.raises(TrainerConfigError):
ParameterRandomizationSettings.structure(
invalid_distribution_dict, Dict[str, ParameterRandomizationSettings]
)
# Check non-Dict input
with pytest.raises(TrainerConfigError):
ParameterRandomizationSettings.structure(
"notadict", Dict[str, ParameterRandomizationSettings]
)

96
ml-agents/mlagents/trainers/tests/test_sampler_class.py


import pytest
from mlagents.trainers.sampler_class import SamplerManager
from mlagents.trainers.sampler_class import (
UniformSampler,
MultiRangeUniformSampler,
GaussianSampler,
)
from mlagents.trainers.exception import TrainerError
def sampler_config_1():
return {
"mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10},
"gravity": {
"sampler-type": "multirange_uniform",
"intervals": [[8, 11], [15, 20]],
},
}
def check_value_in_intervals(val, intervals):
check_in_bounds = [a <= val <= b for a, b in intervals]
return any(check_in_bounds)
def test_sampler_config_1():
config = sampler_config_1()
sampler = SamplerManager(config)
assert sampler.is_empty() is False
assert isinstance(sampler.samplers["mass"], UniformSampler)
assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler)
cur_sample = sampler.sample_all()
# Check uniform sampler for mass
assert sampler.samplers["mass"].min_value == config["mass"]["min_value"]
assert sampler.samplers["mass"].max_value == config["mass"]["max_value"]
assert config["mass"]["min_value"] <= cur_sample["mass"]
assert config["mass"]["max_value"] >= cur_sample["mass"]
# Check multirange_uniform sampler for gravity
assert sampler.samplers["gravity"].intervals == config["gravity"]["intervals"]
assert check_value_in_intervals(
cur_sample["gravity"], sampler.samplers["gravity"].intervals
)
def sampler_config_2():
return {"angle": {"sampler-type": "gaussian", "mean": 0, "st_dev": 1}}
def test_sampler_config_2():
config = sampler_config_2()
sampler = SamplerManager(config)
assert sampler.is_empty() is False
assert isinstance(sampler.samplers["angle"], GaussianSampler)
# Check angle gaussian sampler
assert sampler.samplers["angle"].mean == config["angle"]["mean"]
assert sampler.samplers["angle"].st_dev == config["angle"]["st_dev"]
def test_empty_samplers():
empty_sampler = SamplerManager({})
assert empty_sampler.is_empty()
empty_cur_sample = empty_sampler.sample_all()
assert empty_cur_sample == {}
none_sampler = SamplerManager(None)
assert none_sampler.is_empty()
none_cur_sample = none_sampler.sample_all()
assert none_cur_sample == {}
def incorrect_uniform_sampler():
# Do not specify required arguments to uniform sampler
return {"mass": {"sampler-type": "uniform", "min-value": 10}}
def incorrect_sampler_config():
# Do not specify 'sampler-type' key
return {"mass": {"min-value": 2, "max-value": 30}}
def test_incorrect_uniform_sampler():
config = incorrect_uniform_sampler()
with pytest.raises(TrainerError):
SamplerManager(config)
def test_incorrect_sampler():
config = incorrect_sampler_config()
with pytest.raises(TrainerError):
SamplerManager(config)
正在加载...
取消
保存