|
|
|
|
|
|
# # Unity ML-Agents Toolkit |
|
|
|
import yaml |
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
|
|
|
|
|
|
import mlagents.trainers |
|
|
|
import mlagents_envs |
|
|
|
from mlagents import tf_utils |
|
|
|
|
|
|
GaugeWriter, |
|
|
|
ConsoleWriter, |
|
|
|
) |
|
|
|
from mlagents.trainers.sac.trainer import SACTrainer |
|
|
|
from mlagents.trainers.ppo.trainer import PPOTrainer |
|
|
|
from mlagents.trainers.cli_utils import parser |
|
|
|
from mlagents_envs.environment import UnityEnvironment |
|
|
|
from mlagents.trainers.settings import RunOptions |
|
|
|
|
|
|
return all_subclasses |
|
|
|
|
|
|
|
|
|
|
|
def get_initializer_trainer(paths: List[str]) -> List: |
|
|
|
def get_initializer_and_trainer(paths: List[str]) -> Optional[dict]: |
|
|
|
original_initializers = set(Initializer.__subclasses__()) |
|
|
|
original_trainers = set(RLTrainer.__subclasses__()) |
|
|
|
logger.info( |
|
|
|
|
|
|
|
|
|
|
new_initializers = set(get_all_subclasses(Initializer)) |
|
|
|
if len(new_initializers) == 0: |
|
|
|
return [] |
|
|
|
return None |
|
|
|
elif len(new_initializers) == 1: |
|
|
|
# load the initializer |
|
|
|
logger.info("Registering new initializer") |
|
|
|
|
|
|
all_trainers = set(get_all_subclasses(RLTrainer)) |
|
|
|
new_trainers = list(all_trainers - original_trainers) |
|
|
|
logger.info(f"Found {len(new_trainers)} new trainers") |
|
|
|
return new_trainers |
|
|
|
|
|
|
|
new_trainer_map = dict() |
|
|
|
for key, value in discovered_plugins.items(): |
|
|
|
trainer_name = importlib.import_module(key) |
|
|
|
|
|
|
|
for name, obj in inspect.getmembers(trainer_name): |
|
|
|
if inspect.isclass(obj) and issubclass(obj, PPOTrainer) and obj != PPOTrainer: |
|
|
|
print(f"Found a sub trainer of PPO Trainer: {obj}") |
|
|
|
new_trainer_map[key] = obj |
|
|
|
if inspect.isclass(obj) and issubclass(obj, SACTrainer) and obj != SACTrainer: |
|
|
|
print(f"Found a sub trainer of SAC Trainer: {obj}") |
|
|
|
new_trainer_map[key] = obj |
|
|
|
return new_trainer_map |
|
|
|
|
|
|
|
else: |
|
|
|
raise ValueError( |
|
|
|
"There should be exactly one initializer passed through plugins option." |
|
|
|
|
|
|
options.environment_parameters, run_seed, restore=checkpoint_settings.resume |
|
|
|
) |
|
|
|
|
|
|
|
new_trainers = get_initializer_trainer(options.plugins) |
|
|
|
print(new_trainers) |
|
|
|
new_trainer_map = get_initializer_and_trainer(options.plugins) |
|
|
|
trainer_factory = TrainerFactory( |
|
|
|
options.behaviors, |
|
|
|
write_path, |
|
|
|
|
|
|
env_parameter_manager, |
|
|
|
new_trainer_map, |
|
|
|
maybe_init_path, |
|
|
|
False, |
|
|
|
) |
|
|
|