# # Unity ML-Agents Toolkit
# ## ML-Agent Learning
"""Launches trainers for each External Brain in a Unity Environment."""

import os
import threading
from typing import Dict, Set, List
from collections import defaultdict

import numpy as np

from mlagents.tf_utils import tf

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.env_manager import EnvManager
from mlagents_envs.exception import (
    UnityEnvironmentException,
    UnityCommunicationException,
    UnityCommunicatorStoppedException,
)
from mlagents_envs.timers import (
    hierarchical_timer,
    timed,
    get_timer_stack_for_thread,
    merge_gauges,
)
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManager


class TrainerController:
    """Creates trainers for the behaviors of an environment and drives them.

    One trainer is generated per brain name. Trainers flagged as ``threaded``
    are advanced on their own daemon thread; the others are advanced inline
    from :meth:`advance`.
    """

    def __init__(
        self,
        trainer_factory: TrainerFactory,
        output_path: str,
        run_id: str,
        param_manager: EnvironmentParameterManager,
        train: bool,
        training_seed: int,
    ):
        """
        :param trainer_factory: Factory used to generate a Trainer for each brain
            name. Its ``ghost_controller`` is also kept and consulted when deciding
            whether the environment must be reset.
        :param output_path: Path to save the model.
        :param run_id: The sub-directory name for model and summary statistics
        :param param_manager: EnvironmentParameterManager object which stores
            information about all environment parameters.
        :param train: Whether to train model, or only run inference.
        :param training_seed: Seed to use for Numpy and Tensorflow random number
            generation.
        """
        self.trainers: Dict[str, Trainer] = {}
        # Maps a brain name to every full behavior id that was registered for it.
        self.brain_name_to_identifier: Dict[str, Set[str]] = defaultdict(set)
        self.trainer_factory = trainer_factory
        self.output_path = output_path
        self.logger = get_logger(__name__)
        self.run_id = run_id
        self.train_model = train
        self.param_manager = param_manager
        self.ghost_controller = self.trainer_factory.ghost_controller

        self.trainer_threads: List[threading.Thread] = []
        # Polled by trainer_update_func; set to True by join_threads to stop
        # the trainer threads.
        self.kill_trainers = False
        np.random.seed(training_seed)
        tf.set_random_seed(training_seed)

    @timed
    def _save_models(self) -> None:
        """
        Saves current model to checkpoint folder.
        """
        for trainer in self.trainers.values():
            trainer.save_model()
        self.logger.info("Saved Model")

    def _save_model_when_interrupted(self) -> None:
        """
        Logs that learning was interrupted, then saves all models.
        """
        self.logger.info(
            "Learning was interrupted. Please wait while the graph is generated."
        )
        self._save_models()

    def _export_graph(self) -> None:
        """
        Saves models for all trainers.
        """
        for trainer in self.trainers.values():
            trainer.save_model()

    @staticmethod
    def _create_output_path(output_path: str) -> None:
        """
        Creates the output directory (and any missing parents) if needed.

        :param output_path: Directory that will hold the generated model.
        :raises UnityEnvironmentException: If the directory cannot be created or
            accessed. The underlying error is attached as the cause.
        """
        try:
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test followed by os.makedirs().
            os.makedirs(output_path, exist_ok=True)
        except Exception as err:
            raise UnityEnvironmentException(
                f"The folder {output_path} containing the "
                "generated model could not be "
                "accessed. Please make sure the "
                "permissions are set correctly."
            ) from err

    @timed
    def _reset_env(self, env: EnvManager) -> None:
        """
        Resets the environment using the parameter manager's current samplers.

        :param env: Environment manager to reset.
        """
        new_config = self.param_manager.get_current_samplers()
        env.reset(config=new_config)

    def _not_done_training(self) -> bool:
        """
        Tells whether the main loop of start_learning should keep running.

        :return: True when at least one trainer should still train, when the
            controller is not training (inference only), or when no trainer has
            been created yet.
        """
        return (
            any(t.should_still_train for t in self.trainers.values())
            or not self.train_model
            or len(self.trainers) == 0
        )

    def _create_trainer_and_manager(
        self, env_manager: EnvManager, name_behavior_id: str
    ) -> None:
        """
        Wires a behavior id to its trainer, a new policy and an AgentManager.

        The trainer for the parsed brain name is reused when it already exists;
        otherwise it is generated by the trainer factory and, if it is threaded,
        given a dedicated daemon thread.

        :param env_manager: Environment manager that receives the agent manager
            and the policy.
        :param name_behavior_id: Full behavior id as reported by the environment.
        """
        parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id)
        brain_name = parsed_behavior_id.brain_name
        trainerthread = None
        try:
            trainer = self.trainers[brain_name]
        except KeyError:
            trainer = self.trainer_factory.generate(brain_name)
            self.trainers[brain_name] = trainer
            if trainer.threaded:
                # Only create trainer thread for new trainers
                trainerthread = threading.Thread(
                    target=self.trainer_update_func, args=(trainer,), daemon=True
                )
                self.trainer_threads.append(trainerthread)

        policy = trainer.create_policy(
            parsed_behavior_id, env_manager.training_behaviors[name_behavior_id]
        )
        trainer.add_policy(parsed_behavior_id, policy)

        agent_manager = AgentManager(
            policy,
            name_behavior_id,
            trainer.stats_reporter,
            trainer.parameters.time_horizon,
            threaded=trainer.threaded,
        )
        env_manager.set_agent_manager(name_behavior_id, agent_manager)
        env_manager.set_policy(name_behavior_id, policy)
        self.brain_name_to_identifier[brain_name].add(name_behavior_id)

        trainer.publish_policy_queue(agent_manager.policy_queue)
        trainer.subscribe_trajectory_queue(agent_manager.trajectory_queue)

        # Only start new trainers. The thread is started last, once the policy
        # and the queues above have been registered.
        if trainerthread is not None:
            trainerthread.start()

    def _create_trainers_and_managers(
        self, env_manager: EnvManager, behavior_ids: Set[str]
    ) -> None:
        """
        Calls _create_trainer_and_manager for every id in behavior_ids.

        :param env_manager: Environment manager the behaviors belong to.
        :param behavior_ids: Behavior ids to set up.
        """
        for behavior_id in behavior_ids:
            self._create_trainer_and_manager(env_manager, behavior_id)

    @timed
    def start_learning(self, env_manager: EnvManager) -> None:
        """
        Runs the main loop until _not_done_training() returns False.

        Each iteration sets up trainers for behaviors that appeared since the
        previous iteration, advances the environment and then calls
        reset_env_if_ready once per step returned by advance().

        A KeyboardInterrupt or UnityCommunicatorStoppedException stops the loop
        without being propagated; UnityCommunicationException and
        UnityEnvironmentException are re-raised after the trainer threads were
        joined. In every case models are saved on the way out when training is
        enabled.

        :param env_manager: Environment manager to step.
        """
        self._create_output_path(self.output_path)
        tf.reset_default_graph()
        last_brain_behavior_ids: Set[str] = set()
        try:
            # Initial reset
            self._reset_env(env_manager)
            while self._not_done_training():
                external_brain_behavior_ids = set(env_manager.training_behaviors.keys())
                new_behavior_ids = external_brain_behavior_ids - last_brain_behavior_ids
                self._create_trainers_and_managers(env_manager, new_behavior_ids)
                last_brain_behavior_ids = external_brain_behavior_ids
                n_steps = self.advance(env_manager)
                for _ in range(n_steps):
                    self.reset_env_if_ready(env_manager)
            # Stop advancing trainers
            self.join_threads()
        except (
            KeyboardInterrupt,
            UnityCommunicationException,
            UnityEnvironmentException,
            UnityCommunicatorStoppedException,
        ) as ex:
            self.join_threads()
            self.logger.info(
                "Learning was interrupted. Please wait while the graph is generated."
            )
            if not isinstance(
                ex, (KeyboardInterrupt, UnityCommunicatorStoppedException)
            ):
                # If the environment failed, we want to make sure to raise
                # the exception so we exit the process with an return code of 1.
                raise
        finally:
            if self.train_model:
                self._save_models()

    def end_trainer_episodes(self) -> None:
        """
        Calls end_episode() on every trainer.
        """
        # Reward buffers reset takes place only for curriculum learning
        # else no reset.
        for trainer in self.trainers.values():
            trainer.end_episode()

    def reset_env_if_ready(self, env: EnvManager) -> None:
        """
        Updates the lessons and resets or reconfigures the environment if needed.

        :param env: Environment manager to reset or to send new parameters to.
        """
        # Snapshot, per trainer, the reward buffer, current step and max steps.
        reward_buff = {k: list(t.reward_buffer) for (k, t) in self.trainers.items()}
        curr_step = {k: int(t.step) for (k, t) in self.trainers.items()}
        max_step = {k: int(t.get_max_steps) for (k, t) in self.trainers.items()}
        # Attempt to increment the lessons of the brains who
        # were ready.
        updated, param_must_reset = self.param_manager.update_lessons(
            curr_step, max_step, reward_buff
        )
        if updated:
            for trainer in self.trainers.values():
                trainer.reward_buffer.clear()
        # If ghost trainer swapped teams
        ghost_controller_reset = self.ghost_controller.should_reset()
        if param_must_reset or ghost_controller_reset:
            self._reset_env(env)  # This reset also sends the new config to env
            self.end_trainer_episodes()
        elif updated:
            # No reset required: only push the new samplers to the environment.
            env.set_env_parameters(self.param_manager.get_current_samplers())

    @timed
    def advance(self, env: EnvManager) -> int:
        """
        Advances the environment, reports lessons and advances inline trainers.

        :param env: Environment manager to advance.
        :return: The value returned by ``env.advance()``.
        """
        # Get steps
        with hierarchical_timer("env_step"):
            num_steps = env.advance()

        # Report current lesson for each environment parameter
        lesson_numbers = self.param_manager.get_current_lesson_number()
        for param_name, lesson_number in lesson_numbers.items():
            for trainer in self.trainers.values():
                trainer.stats_reporter.set_stat(
                    f"Environment/Lesson/{param_name}", lesson_number
                )

        # Threaded trainers are advanced by trainer_update_func on their own
        # thread, so only the non-threaded ones are advanced here.
        for trainer in self.trainers.values():
            if not trainer.threaded:
                with hierarchical_timer("trainer_advance"):
                    trainer.advance()

        return num_steps

    def join_threads(self, timeout_seconds: float = 1.0) -> None:
        """
        Wait for threads to finish, and merge their timer information into the
        main thread.

        :param timeout_seconds: Maximum time to wait for each trainer thread.
        :return:
        """
        self.kill_trainers = True
        for trainer_thread in self.trainer_threads:
            try:
                trainer_thread.join(timeout_seconds)
            except Exception as err:
                # Joining is best-effort (Thread.join raises RuntimeError for a
                # thread that was never started), so shutdown continues; the
                # failure is logged instead of being dropped silently.
                self.logger.debug(
                    "Could not join trainer thread %s: %s", trainer_thread.name, err
                )

        with hierarchical_timer("trainer_threads") as main_timer_node:
            for trainer_thread in self.trainer_threads:
                thread_timer_stack = get_timer_stack_for_thread(trainer_thread)
                if thread_timer_stack:
                    main_timer_node.merge(
                        thread_timer_stack.root,
                        root_name="thread_root",
                        is_parallel=True,
                    )
                    merge_gauges(thread_timer_stack.gauges)

    def trainer_update_func(self, trainer: Trainer) -> None:
        """
        Thread target: advances the trainer until kill_trainers becomes True.

        :param trainer: Trainer to advance.
        """
        while not self.kill_trainers:
            with hierarchical_timer("trainer_advance"):
                trainer.advance()