Andrew Cohen
5 年前
当前提交
fa5dae1a
共有 2 个文件被更改,包括 63 次插入 和 96 次删除
-
63ml-agents/mlagents/trainers/tests/test_settings.py
-
96ml-agents/mlagents/trainers/tests/test_sampler_class.py
|
|||
import pytest |
|||
|
|||
from mlagents.trainers.sampler_class import SamplerManager |
|||
from mlagents.trainers.sampler_class import ( |
|||
UniformSampler, |
|||
MultiRangeUniformSampler, |
|||
GaussianSampler, |
|||
) |
|||
from mlagents.trainers.exception import TrainerError |
|||
|
|||
|
|||
def sampler_config_1(): |
|||
return { |
|||
"mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10}, |
|||
"gravity": { |
|||
"sampler-type": "multirange_uniform", |
|||
"intervals": [[8, 11], [15, 20]], |
|||
}, |
|||
} |
|||
|
|||
|
|||
def check_value_in_intervals(val, intervals): |
|||
check_in_bounds = [a <= val <= b for a, b in intervals] |
|||
return any(check_in_bounds) |
|||
|
|||
|
|||
def test_sampler_config_1(): |
|||
config = sampler_config_1() |
|||
sampler = SamplerManager(config) |
|||
|
|||
assert sampler.is_empty() is False |
|||
assert isinstance(sampler.samplers["mass"], UniformSampler) |
|||
assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler) |
|||
|
|||
cur_sample = sampler.sample_all() |
|||
|
|||
# Check uniform sampler for mass |
|||
assert sampler.samplers["mass"].min_value == config["mass"]["min_value"] |
|||
assert sampler.samplers["mass"].max_value == config["mass"]["max_value"] |
|||
assert config["mass"]["min_value"] <= cur_sample["mass"] |
|||
assert config["mass"]["max_value"] >= cur_sample["mass"] |
|||
|
|||
# Check multirange_uniform sampler for gravity |
|||
assert sampler.samplers["gravity"].intervals == config["gravity"]["intervals"] |
|||
assert check_value_in_intervals( |
|||
cur_sample["gravity"], sampler.samplers["gravity"].intervals |
|||
) |
|||
|
|||
|
|||
def sampler_config_2(): |
|||
return {"angle": {"sampler-type": "gaussian", "mean": 0, "st_dev": 1}} |
|||
|
|||
|
|||
def test_sampler_config_2(): |
|||
config = sampler_config_2() |
|||
sampler = SamplerManager(config) |
|||
assert sampler.is_empty() is False |
|||
assert isinstance(sampler.samplers["angle"], GaussianSampler) |
|||
|
|||
# Check angle gaussian sampler |
|||
assert sampler.samplers["angle"].mean == config["angle"]["mean"] |
|||
assert sampler.samplers["angle"].st_dev == config["angle"]["st_dev"] |
|||
|
|||
|
|||
def test_empty_samplers(): |
|||
empty_sampler = SamplerManager({}) |
|||
assert empty_sampler.is_empty() |
|||
empty_cur_sample = empty_sampler.sample_all() |
|||
assert empty_cur_sample == {} |
|||
|
|||
none_sampler = SamplerManager(None) |
|||
assert none_sampler.is_empty() |
|||
none_cur_sample = none_sampler.sample_all() |
|||
assert none_cur_sample == {} |
|||
|
|||
|
|||
def incorrect_uniform_sampler(): |
|||
# Do not specify required arguments to uniform sampler |
|||
return {"mass": {"sampler-type": "uniform", "min-value": 10}} |
|||
|
|||
|
|||
def incorrect_sampler_config(): |
|||
# Do not specify 'sampler-type' key |
|||
return {"mass": {"min-value": 2, "max-value": 30}} |
|||
|
|||
|
|||
def test_incorrect_uniform_sampler(): |
|||
config = incorrect_uniform_sampler() |
|||
with pytest.raises(TrainerError): |
|||
SamplerManager(config) |
|||
|
|||
|
|||
def test_incorrect_sampler(): |
|||
config = incorrect_sampler_config() |
|||
with pytest.raises(TrainerError): |
|||
SamplerManager(config) |
撰写
预览
正在加载...
取消
保存
Reference in new issue