您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
96 行
2.9 KiB
96 行
2.9 KiB
import pytest
|
|
|
|
from mlagents.trainers.sampler_class import SamplerManager
|
|
from mlagents.trainers.sampler_class import (
|
|
UniformSampler,
|
|
MultiRangeUniformSampler,
|
|
GaussianSampler,
|
|
)
|
|
from mlagents.trainers.exception import TrainerError
|
|
|
|
|
|
def sampler_config_1():
|
|
return {
|
|
"mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10},
|
|
"gravity": {
|
|
"sampler-type": "multirange_uniform",
|
|
"intervals": [[8, 11], [15, 20]],
|
|
},
|
|
}
|
|
|
|
|
|
def check_value_in_intervals(val, intervals):
|
|
check_in_bounds = [a <= val <= b for a, b in intervals]
|
|
return any(check_in_bounds)
|
|
|
|
|
|
def test_sampler_config_1():
|
|
config = sampler_config_1()
|
|
sampler = SamplerManager(config)
|
|
|
|
assert sampler.is_empty() is False
|
|
assert isinstance(sampler.samplers["mass"], UniformSampler)
|
|
assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler)
|
|
|
|
cur_sample = sampler.sample_all()
|
|
|
|
# Check uniform sampler for mass
|
|
assert sampler.samplers["mass"].min_value == config["mass"]["min_value"]
|
|
assert sampler.samplers["mass"].max_value == config["mass"]["max_value"]
|
|
assert config["mass"]["min_value"] <= cur_sample["mass"]
|
|
assert config["mass"]["max_value"] >= cur_sample["mass"]
|
|
|
|
# Check multirange_uniform sampler for gravity
|
|
assert sampler.samplers["gravity"].intervals == config["gravity"]["intervals"]
|
|
assert check_value_in_intervals(
|
|
cur_sample["gravity"], sampler.samplers["gravity"].intervals
|
|
)
|
|
|
|
|
|
def sampler_config_2():
|
|
return {"angle": {"sampler-type": "gaussian", "mean": 0, "st_dev": 1}}
|
|
|
|
|
|
def test_sampler_config_2():
|
|
config = sampler_config_2()
|
|
sampler = SamplerManager(config)
|
|
assert sampler.is_empty() is False
|
|
assert isinstance(sampler.samplers["angle"], GaussianSampler)
|
|
|
|
# Check angle gaussian sampler
|
|
assert sampler.samplers["angle"].mean == config["angle"]["mean"]
|
|
assert sampler.samplers["angle"].st_dev == config["angle"]["st_dev"]
|
|
|
|
|
|
def test_empty_samplers():
|
|
empty_sampler = SamplerManager({})
|
|
assert empty_sampler.is_empty()
|
|
empty_cur_sample = empty_sampler.sample_all()
|
|
assert empty_cur_sample == {}
|
|
|
|
none_sampler = SamplerManager(None)
|
|
assert none_sampler.is_empty()
|
|
none_cur_sample = none_sampler.sample_all()
|
|
assert none_cur_sample == {}
|
|
|
|
|
|
def incorrect_uniform_sampler():
|
|
# Do not specify required arguments to uniform sampler
|
|
return {"mass": {"sampler-type": "uniform", "min-value": 10}}
|
|
|
|
|
|
def incorrect_sampler_config():
|
|
# Do not specify 'sampler-type' key
|
|
return {"mass": {"min-value": 2, "max-value": 30}}
|
|
|
|
|
|
def test_incorrect_uniform_sampler():
|
|
config = incorrect_uniform_sampler()
|
|
with pytest.raises(TrainerError):
|
|
SamplerManager(config)
|
|
|
|
|
|
def test_incorrect_sampler():
|
|
config = incorrect_sampler_config()
|
|
with pytest.raises(TrainerError):
|
|
SamplerManager(config)
|