您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
110 行
4.4 KiB
110 行
4.4 KiB
from typing import Any, Dict
|
|
|
|
from mlagents.trainers.meta_curriculum import MetaCurriculum
|
|
from mlagents.envs.exception import UnityEnvironmentException
|
|
from mlagents.trainers.trainer import Trainer
|
|
from mlagents.envs.brain import BrainParameters
|
|
from mlagents.trainers.ppo.trainer import PPOTrainer
|
|
from mlagents.trainers.sac.trainer import SACTrainer
|
|
from mlagents.trainers.bc.offline_trainer import OfflineBCTrainer
|
|
from mlagents.trainers.bc.online_trainer import OnlineBCTrainer
|
|
|
|
|
|
def initialize_trainers(
    trainer_config: Dict[str, Any],
    external_brains: Dict[str, BrainParameters],
    summaries_dir: str,
    run_id: str,
    model_path: str,
    keep_checkpoints: int,
    train_model: bool,
    load_model: bool,
    seed: int,
    meta_curriculum: MetaCurriculum = None,
    multi_gpu: bool = False,
) -> Dict[str, Trainer]:
    """
    Initializes trainers given a provided trainer configuration and set of brains from the environment, as well as
    some general training session options.

    :param trainer_config: Original trainer configuration loaded from YAML. Must contain a "default"
        section. A section named after a brain is merged over the defaults; if that section's value
        is not a dict, it is treated as the name of another section to use instead (an alias).
    :param external_brains: BrainParameters provided by the Unity environment, keyed by brain name
    :param summaries_dir: Directory to store trainer summary statistics
    :param run_id: Run ID to associate with this training run
    :param model_path: Path to save the model
    :param keep_checkpoints: How many model checkpoints to keep
    :param train_model: Whether to train the model (vs. run inference)
    :param load_model: Whether to load the model or randomly initialize
    :param seed: The random seed to use
    :param meta_curriculum: Optional meta_curriculum, used to determine a reward buffer length for
        PPOTrainer and SACTrainer. Brains that have no curriculum fall back to a length of 1.
    :param multi_gpu: Whether to use multi-GPU training (only passed to PPOTrainer)
    :return: A dict mapping every brain name in external_brains to its constructed Trainer.
    :raises UnityEnvironmentException: If a brain's trainer type is unknown, or if a brain's config
        section aliases a missing section or forms an alias cycle.
    """
    # Resolve every brain's parameters before constructing any trainer, so configuration errors
    # surface before any trainer has been built.
    trainer_parameters_dict = {
        brain_name: _build_trainer_parameters(
            trainer_config,
            brain_name,
            summaries_dir,
            run_id,
            model_path,
            keep_checkpoints,
        )
        for brain_name in external_brains
    }

    trainers: Dict[str, Trainer] = {}
    for brain_name in external_brains:
        brain_parameters = external_brains[brain_name]
        trainer_parameters = trainer_parameters_dict[brain_name]
        trainer_type = trainer_parameters["trainer"]
        if trainer_type == "offline_bc":
            trainers[brain_name] = OfflineBCTrainer(
                brain_parameters,
                trainer_parameters,
                train_model,
                load_model,
                seed,
                run_id,
            )
        elif trainer_type == "online_bc":
            trainers[brain_name] = OnlineBCTrainer(
                brain_parameters,
                trainer_parameters,
                train_model,
                load_model,
                seed,
                run_id,
            )
        elif trainer_type == "ppo":
            trainers[brain_name] = PPOTrainer(
                brain_parameters,
                _get_min_lesson_length(meta_curriculum, brain_name),
                trainer_parameters,
                train_model,
                load_model,
                seed,
                run_id,
                multi_gpu,
            )
        elif trainer_type == "sac":
            trainers[brain_name] = SACTrainer(
                brain_parameters,
                _get_min_lesson_length(meta_curriculum, brain_name),
                trainer_parameters,
                train_model,
                load_model,
                seed,
                run_id,
            )
        else:
            raise UnityEnvironmentException(
                "The trainer config contains "
                "an unknown trainer type for "
                "brain {}".format(brain_name)
            )
    return trainers


def _build_trainer_parameters(
    trainer_config: Dict[str, Any],
    brain_name: str,
    summaries_dir: str,
    run_id: str,
    model_path: str,
    keep_checkpoints: int,
) -> Dict[str, Any]:
    """
    Build one brain's trainer parameters: the "default" section, plus run-specific paths and
    checkpoint count, overridden by the brain's own (alias-resolved) section if present.
    """
    # Shallow copy: top-level keys are independent per brain, but nested values (e.g. dicts)
    # are still the same objects as in trainer_config["default"].
    trainer_parameters = trainer_config["default"].copy()
    trainer_parameters["summary_path"] = "{basedir}/{name}".format(
        basedir=summaries_dir, name=str(run_id) + "_" + brain_name
    )
    trainer_parameters["model_path"] = "{basedir}/{name}".format(
        basedir=model_path, name=brain_name
    )
    trainer_parameters["keep_checkpoints"] = keep_checkpoints
    if brain_name in trainer_config:
        trainer_parameters.update(_resolve_brain_section(trainer_config, brain_name))
    return trainer_parameters


def _resolve_brain_section(
    trainer_config: Dict[str, Any], brain_name: str
) -> Dict[str, Any]:
    """
    Return the dict section that applies to brain_name, following aliases: a section whose
    value is not a dict is treated as the key of another section to use instead.

    Assumes brain_name is a key of trainer_config.

    :raises UnityEnvironmentException: If an alias names a section that does not exist, or if
        the aliases form a cycle (which would otherwise loop forever).
    """
    chain = [brain_name]  # keys visited so far, in order, for cycle detection and messages
    section = trainer_config[brain_name]
    while not isinstance(section, dict):
        alias = section
        if alias in chain:
            raise UnityEnvironmentException(
                "The trainer config for brain {} contains a circular alias: {}".format(
                    brain_name, " -> ".join(str(key) for key in chain + [alias])
                )
            )
        if alias not in trainer_config:
            raise UnityEnvironmentException(
                "The trainer config for brain {} refers to a missing section: {}".format(
                    brain_name, alias
                )
            )
        chain.append(alias)
        section = trainer_config[alias]
    return section


def _get_min_lesson_length(meta_curriculum: MetaCurriculum, brain_name: str) -> int:
    """
    Return the reward buffer length for a PPO/SAC trainer. This is the brain's curriculum
    min_lesson_length when a meta curriculum is given and has an entry for the brain,
    otherwise 1.
    """
    if not meta_curriculum:
        return 1
    curriculums = meta_curriculum.brains_to_curriculums
    if brain_name not in curriculums:
        # A meta curriculum need not cover every brain; previously this raised a bare KeyError.
        import logging

        logging.getLogger(__name__).warning(
            "Meta curriculum enabled, but no curriculum for brain %s. Brains with curricula: %s",
            brain_name,
            list(curriculums.keys()),
        )
        return 1
    return curriculums[brain_name].min_lesson_length
|