|
|
|
|
|
|
from mlagents.tf_utils import tf |
|
|
|
|
|
|
|
from mlagents_envs.logging_util import get_logger |
|
|
|
from mlagents.trainers.env_manager import EnvManager |
|
|
|
from mlagents.trainers.env_manager import EnvManager, EnvironmentStep |
|
|
|
from mlagents_envs.exception import ( |
|
|
|
UnityEnvironmentException, |
|
|
|
UnityCommunicationException, |
|
|
|
|
|
|
self.train_model = train |
|
|
|
self.param_manager = param_manager |
|
|
|
self.ghost_controller = self.trainer_factory.ghost_controller |
|
|
|
self.registered_behavior_ids: Set[str] = set() |
|
|
|
|
|
|
|
self.trainer_threads: List[threading.Thread] = [] |
|
|
|
self.kill_trainers = False |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
@timed |
|
|
|
def _reset_env(self, env: EnvManager) -> None: |
|
|
|
def _reset_env(self, env_manager: EnvManager) -> None: |
|
|
|
"""Resets the environment. |
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
|
new_config = self.param_manager.get_current_samplers() |
|
|
|
env.reset(config=new_config) |
|
|
|
env_manager.reset(config=new_config) |
|
|
|
# Register any new behavior ids that were generated on the reset. |
|
|
|
self._register_new_behaviors(env_manager, env_manager.first_step_infos) |
|
|
|
|
|
|
|
def _not_done_training(self) -> bool: |
|
|
|
return ( |
|
|
|
|
|
|
def start_learning(self, env_manager: EnvManager) -> None: |
|
|
|
self._create_output_path(self.output_path) |
|
|
|
tf.reset_default_graph() |
|
|
|
last_brain_behavior_ids: Set[str] = set() |
|
|
|
external_brain_behavior_ids = set(env_manager.training_behaviors.keys()) |
|
|
|
new_behavior_ids = external_brain_behavior_ids - last_brain_behavior_ids |
|
|
|
self._create_trainers_and_managers(env_manager, new_behavior_ids) |
|
|
|
last_brain_behavior_ids = external_brain_behavior_ids |
|
|
|
n_steps = self.advance(env_manager) |
|
|
|
for _ in range(n_steps): |
|
|
|
self.reset_env_if_ready(env_manager) |
|
|
|
|
|
|
env.set_env_parameters(self.param_manager.get_current_samplers()) |
|
|
|
|
|
|
|
@timed |
|
|
|
def advance(self, env: EnvManager) -> int: |
|
|
|
def advance(self, env_manager: EnvManager) -> int: |
|
|
|
num_steps = env.advance() |
|
|
|
new_step_infos = env_manager.get_steps() |
|
|
|
self._register_new_behaviors(env_manager, new_step_infos) |
|
|
|
num_steps = env_manager.process_steps(new_step_infos) |
|
|
|
|
|
|
|
# Report current lesson for each environment parameter |
|
|
|
for ( |
|
|
|
|
|
|
trainer.advance() |
|
|
|
|
|
|
|
return num_steps |
|
|
|
|
|
|
|
def _register_new_behaviors( |
|
|
|
self, env_manager: EnvManager, step_infos: List[EnvironmentStep] |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Handle registration (adding trainers and managers) of new behaviors ids. |
|
|
|
:param env_manager: |
|
|
|
:param step_infos: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
step_behavior_ids: Set[str] = set() |
|
|
|
for s in step_infos: |
|
|
|
step_behavior_ids |= set(s.name_behavior_ids) |
|
|
|
new_behavior_ids = step_behavior_ids - self.registered_behavior_ids |
|
|
|
self._create_trainers_and_managers(env_manager, new_behavior_ids) |
|
|
|
self.registered_behavior_ids |= step_behavior_ids |
|
|
|
|
|
|
|
def join_threads(self, timeout_seconds: float = 1.0) -> None: |
|
|
|
""" |
|
|
|