浏览代码

Better error handling if trainer config doesn't contain "default" section (#3063)

/develop
GitHub 5 年前
当前提交
8ca0d810
共有 4 个文件被更改,包括 103 次插入22 次删除
  1. 8
      ml-agents/mlagents/trainers/exception.py
  2. 8
      ml-agents/mlagents/trainers/tests/test_simple_rl.py
  3. 73
      ml-agents/mlagents/trainers/tests/test_trainer_util.py
  4. 36
      ml-agents/mlagents/trainers/trainer_util.py

8
ml-agents/mlagents/trainers/exception.py


pass
class TrainerConfigError(Exception):
"""
Any error related to the configuration of trainers in the ML-Agents Toolkit.
"""
pass
class CurriculumError(TrainerError):
"""
Any error related to training with a curriculum.

8
ml-agents/mlagents/trainers/tests/test_simple_rl.py


pass
PPO_CONFIG = """
default:
PPO_CONFIG = f"""
{BRAIN_NAME}:
trainer: ppo
batch_size: 16
beta: 5.0e-3

gamma: 0.99
"""
SAC_CONFIG = """
default:
SAC_CONFIG = f"""
{BRAIN_NAME}:
trainer: sac
batch_size: 8
buffer_size: 500

73
ml-agents/mlagents/trainers/tests/test_trainer_util.py


from mlagents.trainers.trainer_util import load_config, _load_config
from mlagents.trainers.trainer_metrics import TrainerMetrics
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.brain import BrainParameters
@pytest.fixture

use_curiosity: false
curiosity_strength: 0.0
curiosity_enc_size: 1
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)

BrainParametersMock.return_value.brain_name = "testbrain"
external_brains = {"testbrain": BrainParametersMock()}
with pytest.raises(UnityEnvironmentException):
with pytest.raises(TrainerConfigError):
trainer_factory = trainer_util.TrainerFactory(
trainer_config=bad_config,
summaries_dir=summaries_dir,

trainers[brain_name] = trainer_factory.generate(brain_parameters)
def test_handles_no_default_section():
"""
Make sure the trainer setup handles a missing "default" in the config.
"""
brain_name = "testbrain"
config = dummy_config()
no_default_config = {brain_name: config["default"]}
brain_parameters = BrainParameters(
brain_name=brain_name,
vector_observation_space_size=1,
camera_resolutions=[],
vector_action_space_size=[2],
vector_action_descriptions=[],
vector_action_space_type=0,
)
trainer_factory = trainer_util.TrainerFactory(
trainer_config=no_default_config,
summaries_dir="test_dir",
run_id="testrun",
model_path="model_dir",
keep_checkpoints=1,
train_model=True,
load_model=False,
seed=42,
)
trainer_factory.generate(brain_parameters)
def test_raise_if_no_config_for_brain():
"""
Make sure the trainer setup raises a friendlier exception if both "default" and the brain name
are missing from the config.
"""
brain_name = "testbrain"
config = dummy_config()
bad_config = {"some_other_brain": config["default"]}
brain_parameters = BrainParameters(
brain_name=brain_name,
vector_observation_space_size=1,
camera_resolutions=[],
vector_action_space_size=[2],
vector_action_descriptions=[],
vector_action_space_type=0,
)
trainer_factory = trainer_util.TrainerFactory(
trainer_config=bad_config,
summaries_dir="test_dir",
run_id="testrun",
model_path="model_dir",
keep_checkpoints=1,
train_model=True,
load_model=False,
seed=42,
)
with pytest.raises(TrainerConfigError):
trainer_factory.generate(brain_parameters)
with pytest.raises(UnityEnvironmentException):
with pytest.raises(TrainerConfigError):
load_config("thisFileDefinitelyDoesNotExist.yaml")

- not
- parse
"""
with pytest.raises(UnityEnvironmentException):
with pytest.raises(TrainerConfigError):
fp = io.StringIO(file_contents)
_load_config(fp)

36
ml-agents/mlagents/trainers/trainer_util.py


from typing import Any, Dict, TextIO
from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.ppo.trainer import PPOTrainer

:param multi_gpu: Whether to use multi-GPU training
:return:
"""
trainer_parameters = trainer_config["default"].copy()
if "default" not in trainer_config and brain_name not in trainer_config:
raise TrainerConfigError(
f'Trainer config must have either a "default" section, or a section for the brain name ({brain_name}). '
"See config/trainer_config.yaml for an example."
)
trainer_parameters = trainer_config.get("default", {}).copy()
trainer_parameters["summary_path"] = "{basedir}/{name}".format(
basedir=summaries_dir, name=str(run_id) + "_" + brain_name
)

trainer_parameters.update(trainer_config[_brain_key])
trainer: Trainer = None # type: ignore # will be set to one of these, or raise
if trainer_parameters["trainer"] == "offline_bc":
if "trainer" not in trainer_parameters:
raise TrainerConfigError(
f'The "trainer" key must be set in your trainer config for brain {brain_name} (or the default brain).'
)
trainer_type = trainer_parameters["trainer"]
if trainer_type == "offline_bc":
elif trainer_parameters["trainer"] == "ppo":
elif trainer_type == "ppo":
trainer = PPOTrainer(
brain_parameters,
meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length

run_id,
multi_gpu,
)
elif trainer_parameters["trainer"] == "sac":
elif trainer_type == "sac":
trainer = SACTrainer(
brain_parameters,
meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length

run_id,
)
else:
raise UnityEnvironmentException(
"The trainer config contains "
"an unknown trainer type for "
"brain {}".format(brain_name)
raise TrainerConfigError(
f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}'
)
return trainer

with open(config_path) as data_file:
return _load_config(data_file)
except IOError:
raise UnityEnvironmentException(
f"Config file could not be found at {config_path}."
)
raise TrainerConfigError(f"Config file could not be found at {config_path}.")
raise UnityEnvironmentException(
raise TrainerConfigError(
f"There was an error decoding Config file from {config_path}. "
f"Make sure your file is save using UTF-8"
)

try:
return yaml.safe_load(fp)
except yaml.parser.ParserError as e:
raise UnityEnvironmentException(
raise TrainerConfigError(
"Error parsing yaml file. Please check for formatting errors. "
"A tool such as http://www.yamllint.com/ can be helpful with this."
) from e
正在加载...
取消
保存