浏览代码

[skip ci] adding distributed trainers

/trainer-plugin
Anupam Bhatnagar 4 年前
当前提交
dbd21c95
共有 3 个文件被更改,包括 42 次插入25 次删除
  1. 36
      ml-agents/mlagents/trainers/learn.py
  2. 7
      ml-agents/mlagents/trainers/settings.py
  3. 24
      ml-agents/mlagents/trainers/trainer_util.py

36
ml-agents/mlagents/trainers/learn.py


return all_subclasses
def find_initializer_trainer(paths):
def get_initializer_trainer(paths) -> List:
logger.info(f"Found {len(original_initializers)} initializers and {len(original_trainers)} "
f"trainers.")
print(original_initializers)
print(original_trainers)
# add all plugin paths to system path
print(sys.path)
print(f"plugins available {discovered_plugins}")
logger.info(f"The following plugins are available {discovered_plugins}")
all_initializers = list(get_all_subclasses(Initializer))
all_trainers = set(get_all_subclasses(RLTrainer))
print("Registering new initializer")
print(f"Found {len(all_initializers)} new initializers")
distributed_init = all_initializers[0]()
distributed_init.load()
new_initializers = set(get_all_subclasses(Initializer))
if len(new_initializers) == 1:
# load the initializer
logger.info("Registering new initializer")
distributed_init = list(new_initializers)[0]()
distributed_init.load()
new_trainers = all_trainers - original_trainers
print(f"Found {len(new_trainers)} new trainers")
print(all_trainers)
print(new_trainers)
# construct a list of new trainers
all_trainers = set(get_all_subclasses(RLTrainer))
new_trainers = list(all_trainers - original_trainers)
logger.info(f"Found {len(new_trainers)} new trainers")
return new_trainers
else:
raise ValueError("there should be exactly one initializer passed through plugins option")
def get_version_string() -> str:

options.environment_parameters, run_seed, restore=checkpoint_settings.resume
)
find_initializer_trainer(options.plugins)
new_trainers = get_initializer_trainer(options.plugins)
trainer_factory = TrainerFactory(
options.behaviors,
write_path,

7
ml-agents/mlagents/trainers/settings.py


no_graphics: bool = parser.get_default("no_graphics")
# @attr.s(auto_attribs=True)
# class PluginSettings:
# plugins: Optional[List[str]] = parser.get_default("plugins")
#
@attr.s(auto_attribs=True)
class RunOptions(ExportableSettings):
behaviors: DefaultDict[str, TrainerSettings] = attr.ib(

engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
environment_parameters: Optional[Dict[str, EnvironmentParameterSettings]] = None
checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)
# plugin_settings: PluginSettings = attr.ib(factory=PluginSettings)
# These are options that are relevant to the run itself, and not the engine or environment.
# They will be left here.

cattr.register_structure_hook(EnvironmentSettings, strict_to_cls)
# cattr.register_structure_hook(PluginSettings, strict_to_cls)
cattr.register_structure_hook(EngineSettings, strict_to_cls)
cattr.register_structure_hook(CheckpointSettings, strict_to_cls)
cattr.register_structure_hook(

24
ml-agents/mlagents/trainers/trainer_util.py


from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings, TrainerType
from mlagents_distributed.trainers.ppo.trainer import DistributedPPOTrainer
from mlagents_distributed.distributed_trainers.sac.distributed_sac_trainer import \
DistributedSACTrainer
logger = get_logger(__name__)

)
elif trainer_type == TrainerType.SAC:
trainer = SACTrainer(
brain_name,
min_lesson_length,
trainer_settings,
train_model,
load_model,
seed,
trainer_artifact_path,
)
elif trainer_type == TrainerType.DistributedPPO:
trainer = DistributedPPOTrainer(
brain_name,
min_lesson_length,
trainer_settings,
train_model,
load_model,
seed,
trainer_artifact_path,
)
elif trainer_type == TrainerType.DistributedSAC:
trainer = DistributedSACTrainer(
brain_name,
min_lesson_length,
trainer_settings,

正在加载...
取消
保存