|
|
|
|
|
|
from typing import Any, Dict, TextIO |
|
|
|
|
|
|
|
from mlagents.trainers.meta_curriculum import MetaCurriculum |
|
|
|
from mlagents.envs.exception import UnityEnvironmentException |
|
|
|
from mlagents.trainers.exception import TrainerConfigError |
|
|
|
from mlagents.trainers.trainer import Trainer, UnityTrainerException |
|
|
|
from mlagents.trainers.brain import BrainParameters |
|
|
|
from mlagents.trainers.ppo.trainer import PPOTrainer |
|
|
|
|
|
|
:param multi_gpu: Whether to use multi-GPU training |
|
|
|
:return: |
|
|
|
""" |
|
|
|
trainer_parameters = trainer_config["default"].copy() |
|
|
|
if "default" not in trainer_config and brain_name not in trainer_config: |
|
|
|
raise TrainerConfigError( |
|
|
|
f'Trainer config must have either a "default" section, or a section for the brain name ({brain_name}). ' |
|
|
|
"See config/trainer_config.yaml for an example." |
|
|
|
) |
|
|
|
|
|
|
|
trainer_parameters = trainer_config.get("default", {}).copy() |
|
|
|
trainer_parameters["summary_path"] = "{basedir}/{name}".format( |
|
|
|
basedir=summaries_dir, name=str(run_id) + "_" + brain_name |
|
|
|
) |
|
|
|
|
|
|
trainer_parameters.update(trainer_config[_brain_key]) |
|
|
|
|
|
|
|
trainer: Trainer = None # type: ignore # will be set to one of these, or raise |
|
|
|
if trainer_parameters["trainer"] == "offline_bc": |
|
|
|
if "trainer" not in trainer_parameters: |
|
|
|
raise TrainerConfigError( |
|
|
|
f'The "trainer" key must be set in your trainer config for brain {brain_name} (or the default brain).' |
|
|
|
) |
|
|
|
trainer_type = trainer_parameters["trainer"] |
|
|
|
|
|
|
|
if trainer_type == "offline_bc": |
|
|
|
elif trainer_parameters["trainer"] == "ppo": |
|
|
|
elif trainer_type == "ppo": |
|
|
|
trainer = PPOTrainer( |
|
|
|
brain_parameters, |
|
|
|
meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length |
|
|
|
|
|
|
run_id, |
|
|
|
multi_gpu, |
|
|
|
) |
|
|
|
elif trainer_parameters["trainer"] == "sac": |
|
|
|
elif trainer_type == "sac": |
|
|
|
trainer = SACTrainer( |
|
|
|
brain_parameters, |
|
|
|
meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length |
|
|
|
|
|
|
run_id, |
|
|
|
) |
|
|
|
else: |
|
|
|
raise UnityEnvironmentException( |
|
|
|
"The trainer config contains " |
|
|
|
"an unknown trainer type for " |
|
|
|
"brain {}".format(brain_name) |
|
|
|
raise TrainerConfigError( |
|
|
|
f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}' |
|
|
|
) |
|
|
|
return trainer |
|
|
|
|
|
|
|
|
|
|
with open(config_path) as data_file: |
|
|
|
return _load_config(data_file) |
|
|
|
except IOError: |
|
|
|
raise UnityEnvironmentException( |
|
|
|
f"Config file could not be found at {config_path}." |
|
|
|
) |
|
|
|
raise TrainerConfigError(f"Config file could not be found at {config_path}.") |
|
|
|
raise UnityEnvironmentException( |
|
|
|
raise TrainerConfigError( |
|
|
|
f"There was an error decoding Config file from {config_path}. " |
|
|
|
f"Make sure your file is save using UTF-8" |
|
|
|
) |
|
|
|
|
|
|
try: |
|
|
|
return yaml.safe_load(fp) |
|
|
|
except yaml.parser.ParserError as e: |
|
|
|
raise UnityEnvironmentException( |
|
|
|
raise TrainerConfigError( |
|
|
|
"Error parsing yaml file. Please check for formatting errors. " |
|
|
|
"A tool such as http://www.yamllint.com/ can be helpful with this." |
|
|
|
) from e |