Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可选择25个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符。
 
 
 
 
 

214 行
7.5 KiB

import attr
import pytest
from typing import Dict
from mlagents.trainers.settings import (
RunOptions,
TrainerSettings,
PPOSettings,
SACSettings,
RewardSignalType,
RewardSignalSettings,
CuriositySettings,
ParameterRandomizationSettings,
UniformSettings,
GaussianSettings,
MultiRangeUniformSettings,
TrainerType,
strict_to_cls,
)
from mlagents.trainers.exception import TrainerConfigError
def check_if_different(testobj1: object, testobj2: object) -> None:
    """
    Assert that two objects are not the same object, descending into attrs fields.

    The two arguments themselves must be distinct objects. When both are
    instances of attrs classes, each field of ``testobj1`` holding a dict, a
    list, or an attrs value is compared against the same-named field of
    ``testobj2`` by calling this function recursively.

    :param testobj1: First object; its fields drive the traversal.
    :param testobj2: Second object, expected to expose the same field names.
    """
    assert testobj1 is not testobj2

    # Only attrs instances have fields to walk; anything else stops here.
    if not (attr.has(testobj1.__class__) and attr.has(testobj2.__class__)):
        return

    other_fields = attr.asdict(testobj2, recurse=False)
    for name, value in attr.asdict(testobj1, recurse=False).items():
        if isinstance(value, (dict, list)) or attr.has(value):
            # Note: for dicts and lists only the container identity is
            # compared; the contents of mutables are not inspected.
            check_if_different(value, other_fields[name])
def test_is_new_instance():
    """
    Check that two default-constructed settings objects never share state.

    Separately built RunOptions (and TrainerSettings) instances must be
    distinct objects, including their nested fields (i.e. all factory methods
    are used properly).
    """
    for settings_cls in (RunOptions, TrainerSettings):
        check_if_different(settings_cls(), settings_cls())
def test_no_configuration():
    """
    Check the defaults handed out by an unconfigured RunOptions.

    Looking up a behavior name that was never configured should yield
    TrainerSettings with PPO hyperparameters and an extrinsic reward signal.
    """
    default_options = RunOptions()
    behavior = default_options.behaviors["test"]

    assert isinstance(behavior, TrainerSettings)
    assert isinstance(behavior.hyperparameters, PPOSettings)
    assert RewardSignalType.EXTRINSIC in behavior.reward_signals
def test_strict_to_cls():
    """
    Exercise the strict structuring helper on a small attrs class.

    A dict whose keys all match the class fields is structured into an equal
    instance; a dict with an unknown key, or a non-dict input, must raise
    TrainerConfigError.
    """

    @attr.s(auto_attribs=True)
    class TestAttrsClass:
        field1: int = 0
        field2: str = "test"

    # Happy path: every key corresponds to a declared field.
    valid_fields = {"field1": 1, "field2": "test2"}
    structured = strict_to_cls(valid_fields, TestAttrsClass)
    assert structured == TestAttrsClass(**valid_fields)

    # Rejected inputs: an undeclared key ("field3"), then a plain string.
    unknown_key_fields = {"field3": 1, "field2": "test2"}
    for rejected_input in (unknown_key_fields, "non_dict_input"):
        with pytest.raises(TrainerConfigError):
            strict_to_cls(rejected_input, TestAttrsClass)
def test_trainersettings_structure():
    """
    Test the structuring method of TrainerSettings.

    A valid SAC configuration must produce SAC hyperparameters, an integer
    max_steps and a curiosity reward signal; several malformed configurations
    must be rejected with the expected exception type.
    """
    valid_config = {
        "trainer_type": "sac",
        "hyperparameters": {"batch_size": 1024},
        "max_steps": 1.0,
        "reward_signals": {"curiosity": {"encoding_size": 64}},
    }
    structured = TrainerSettings.structure(valid_config, TrainerSettings)
    assert isinstance(structured.hyperparameters, SACSettings)
    assert structured.trainer_type == TrainerType.SAC
    # max_steps was supplied as the float 1.0 and must come back as an int.
    assert isinstance(structured.max_steps, int)
    assert RewardSignalType.CURIOSITY in structured.reward_signals

    # Each entry pairs the expected exception with the input that triggers it.
    rejected_cases = (
        # Invalid trainer type.
        (
            ValueError,
            {
                "trainer_type": "puppo",
                "hyperparameters": {"batch_size": 1024},
                "max_steps": 1.0,
            },
        ),
        # Invalid hyperparameter name.
        (
            TrainerConfigError,
            {
                "trainer_type": "ppo",
                "hyperparameters": {"notahyperparam": 1024},
                "max_steps": 1.0,
            },
        ),
        # Non-dict input.
        (TrainerConfigError, "notadict"),
        # Hyperparameters specified but trainer type left as default.
        # This shouldn't work as you could specify non-PPO hyperparameters.
        (TrainerConfigError, {"hyperparameters": {"batch_size": 1024}}),
    )
    for expected_error, bad_config in rejected_cases:
        with pytest.raises(expected_error):
            TrainerSettings.structure(bad_config, TrainerSettings)
def test_reward_signal_structure():
    """
    Test the structuring method of RewardSignalSettings.

    This one is special because the target type passed in is a
    Dict[RewardSignalType, RewardSignalSettings] rather than a single class.
    """
    signals_type = Dict[RewardSignalType, RewardSignalSettings]

    structured = RewardSignalSettings.structure(
        {"extrinsic": {"strength": 1.0}, "curiosity": {"strength": 1.0}},
        signals_type,
    )
    extrinsic = structured[RewardSignalType.EXTRINSIC]
    curiosity = structured[RewardSignalType.CURIOSITY]
    assert isinstance(extrinsic, RewardSignalSettings)
    assert isinstance(curiosity, CuriositySettings)

    # Invalid reward signal type.
    with pytest.raises(ValueError):
        RewardSignalSettings.structure({"puppo": {"strength": 1.0}}, signals_type)

    # GAIL entry given only a strength, i.e. missing its demo path.
    with pytest.raises(TypeError):
        RewardSignalSettings.structure({"gail": {"strength": 1.0}}, signals_type)

    # Non-dict input.
    with pytest.raises(TrainerConfigError):
        RewardSignalSettings.structure("notadict", signals_type)
def test_parameter_randomization_structure():
    """
    Test the ParameterRandomizationSettings structure method and its validators.

    Valid uniform, gaussian and multirangeuniform entries must map to their
    dedicated settings classes; malformed entries must raise the expected
    exception type.
    """
    distributions_type = Dict[str, ParameterRandomizationSettings]

    def structure(config):
        # Every case below structures against the same target type.
        return ParameterRandomizationSettings.structure(config, distributions_type)

    distributions = structure(
        {
            "mass": {"uniform": {"min_value": 1.0, "max_value": 2.0}},
            "scale": {"gaussian": {"mean": 1.0, "st_dev": 2.0}},
            "length": {"multirangeuniform": {"intervals": [[1.0, 2.0], [3.0, 4.0]]}},
        }
    )
    expected_classes = (
        ("mass", UniformSettings),
        ("scale", GaussianSettings),
        ("length", MultiRangeUniformSettings),
    )
    for parameter_name, settings_cls in expected_classes:
        assert isinstance(distributions[parameter_name], settings_cls)

    # Invalid distribution type.
    with pytest.raises(ValueError):
        structure({"mass": {"beta": {"alpha": 1.0, "beta": 2.0}}})

    # Uniform whose min_value is greater than its max_value.
    with pytest.raises(TrainerConfigError):
        structure({"mass": {"uniform": {"min_value": 2.0, "max_value": 1.0}}})

    # Multirange interval whose lower bound is greater than its upper bound.
    with pytest.raises(TrainerConfigError):
        structure({"mass": {"multirangeuniform": {"intervals": [[2.0, 1.0]]}}})

    # Multirange containing an interval with a single endpoint.
    with pytest.raises(TrainerConfigError):
        structure({"mass": {"multirangeuniform": {"intervals": [[1.0, 2.0], [3.0]]}}})

    # Non-dict input.
    with pytest.raises(TrainerConfigError):
        structure("notadict")