浏览代码

[skip ci] removing package import statements

/trainer-plugin
Anupam Bhatnagar 4 年前
当前提交
d7f0d457
共有 3 个文件被更改,包括 31 次插入15 次删除
  1. 29
      ml-agents/mlagents/trainers/learn.py
  2. 4
      ml-agents/mlagents/trainers/settings.py
  3. 13
      ml-agents/mlagents/trainers/trainer_util.py

29
ml-agents/mlagents/trainers/learn.py


# # Unity ML-Agents Toolkit
import yaml
import inspect
import mlagents.trainers
import mlagents_envs
from mlagents import tf_utils

GaugeWriter,
ConsoleWriter,
)
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.cli_utils import parser
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.settings import RunOptions

return all_subclasses
def get_initializer_trainer(paths: List[str]) -> List:
def get_initializer_and_trainer(paths: List[str]) -> Optional[dict]:
original_initializers = set(Initializer.__subclasses__())
original_trainers = set(RLTrainer.__subclasses__())
logger.info(

new_initializers = set(get_all_subclasses(Initializer))
if len(new_initializers) == 0:
return []
return None
elif len(new_initializers) == 1:
# load the initializer
logger.info("Registering new initializer")

all_trainers = set(get_all_subclasses(RLTrainer))
new_trainers = list(all_trainers - original_trainers)
logger.info(f"Found {len(new_trainers)} new trainers")
return new_trainers
new_trainer_map = dict()
for key, value in discovered_plugins.items():
trainer_name = importlib.import_module(key)
for name, obj in inspect.getmembers(trainer_name):
if inspect.isclass(obj) and issubclass(obj, PPOTrainer) and obj != PPOTrainer:
print(f"Found a sub trainer of PPO Trainer: {obj}")
new_trainer_map[key] = obj
if inspect.isclass(obj) and issubclass(obj, SACTrainer) and obj != SACTrainer:
print(f"Found a sub trainer of SAC Trainer: {obj}")
new_trainer_map[key] = obj
return new_trainer_map
else:
raise ValueError(
"There should be exactly one initializer passed through plugins option."

options.environment_parameters, run_seed, restore=checkpoint_settings.resume
)
new_trainers = get_initializer_trainer(options.plugins)
print(new_trainers)
new_trainer_map = get_initializer_and_trainer(options.plugins)
trainer_factory = TrainerFactory(
options.behaviors,
write_path,

env_parameter_manager,
new_trainer_map,
maybe_init_path,
False,
)

4
ml-agents/mlagents/trainers/settings.py


class TrainerType(Enum):
PPO: str = "ppo"
SAC: str = "sac"
DistributedPPO: str = "distributed-ppo"
DistributedSAC: str = "distributed-sac"
DistributedPPO: str = "distributed_ppo"
DistributedSAC: str = "distributed_sac"
def to_settings(self) -> type:
_mapping = {

13
ml-agents/mlagents/trainers/trainer_util.py


from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings, TrainerType
from mlagents_distributed.trainers.ppo.ppo_trainer import DistributedPPOTrainer
from mlagents_distributed.trainers.sac.sac_trainer import DistributedSACTrainer
logger = get_logger(__name__)

load_model: bool,
seed: int,
param_manager: EnvironmentParameterManager,
new_trainer_map: dict,
self.init_path = init_path
self.new_trainer_map = new_trainer_map
self.init_path = init_path
self.multi_gpu = multi_gpu
self.ghost_controller = GhostController()

self.ghost_controller,
self.seed,
self.param_manager,
self.new_trainer_map,
self.init_path,
self.multi_gpu,
)

ghost_controller: GhostController,
seed: int,
param_manager: EnvironmentParameterManager,
new_trainer_map: dict,
init_path: str = None,
multi_gpu: bool = False,
) -> Trainer:

:param ghost_controller: The object that coordinates ghost trainers
:param seed: The random seed to use
:param param_manager: EnvironmentParameterManager, used to determine a reward buffer length for PPOTrainer
:param new_trainer_map: a mapping from trainer name to trainer class; to be used with the plugin
:param init_path: Path from which to load model, if different from model_path.
:return:
"""

trainer_artifact_path,
)
elif trainer_type == TrainerType.DistributedPPO:
trainer = DistributedPPOTrainer(
trainer = new_trainer_map[TrainerType.DistributedPPO](
brain_name,
min_lesson_length,
trainer_settings,

trainer_artifact_path,
)
elif trainer_type == TrainerType.DistributedSAC:
trainer = DistributedSACTrainer(
trainer = new_trainer_map[TrainerType.DistributedSAC](
brain_name,
min_lesson_length,
trainer_settings,

正在加载...
取消
保存