浏览代码

[skip ci] adding type annotations

/trainer-plugin
Anupam Bhatnagar 4 年前
当前提交
07daf8b5
共有 4 个文件被更改,包括 14 次插入8 次删除
  1. 2
      ml-agents/mlagents/trainers/cli_utils.py
  2. 13
      ml-agents/mlagents/trainers/learn.py
  3. 2
      ml-agents/mlagents/trainers/settings.py
  4. 5
      ml-agents/mlagents/trainers/trainer_util.py

2
ml-agents/mlagents/trainers/cli_utils.py


nargs="*",
help="Absolute paths of plugins to be loaded",
required=False,
action=DetectDefault
action=DetectDefault,
)
argparser.add_argument(
"--env-args",

13
ml-agents/mlagents/trainers/learn.py


return all_subclasses
def get_initializer_trainer(paths) -> List:
def get_initializer_trainer(paths: List[str]) -> List:
logger.info(f"Found {len(original_initializers)} initializers and {len(original_trainers)} "
f"trainers.")
logger.info(
f"Found {len(original_initializers)} initializers and {len(original_trainers)} "
f"trainers."
)
# add all plugin paths to system path
for p in paths:

logger.info(f"Found {len(new_trainers)} new trainers")
return new_trainers
else:
raise ValueError("there should be exactly one initializer passed through plugins option")
raise ValueError(
"there should be exactly one initializer passed through plugins option"
)
def get_version_string() -> str:

)
new_trainers = get_initializer_trainer(options.plugins)
print(new_trainers)
trainer_factory = TrainerFactory(
options.behaviors,
write_path,

2
ml-agents/mlagents/trainers/settings.py


# These are options that are relevant to the run itself, and not the engine or environment.
# They will be left here.
debug: bool = parser.get_default("debug")
plugins = parser.get_default("plugins")
plugins: List[str] = parser.get_default("plugins")
# Strict conversion
cattr.register_structure_hook(EnvironmentSettings, strict_to_cls)
cattr.register_structure_hook(EngineSettings, strict_to_cls)

5
ml-agents/mlagents/trainers/trainer_util.py


from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings, TrainerType
from mlagents_distributed.trainers.ppo.trainer import DistributedPPOTrainer
from mlagents_distributed.distributed_trainers.sac.distributed_sac_trainer import \
DistributedSACTrainer
from mlagents_distributed.distributed_trainers.sac.distributed_sac_trainer import (
DistributedSACTrainer,
)
logger = get_logger(__name__)

正在加载...
取消
保存