|
|
|
|
|
|
from mlagents.trainers.trainer_util import TrainerFactory |
|
|
|
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers |
|
|
|
from mlagents.trainers.agent_processor import AgentManager |
|
|
|
from mlagents.trainers.ghost.controller import GhostController |
|
|
|
|
|
|
|
|
|
|
|
class TrainerController(object): |
|
|
|
|
|
|
training_seed: int, |
|
|
|
sampler_manager: SamplerManager, |
|
|
|
resampling_interval: Optional[int], |
|
|
|
ghost_controller: GhostController, |
|
|
|
): |
|
|
|
""" |
|
|
|
:param model_path: Path to save the model. |
|
|
|
|
|
|
self.meta_curriculum = meta_curriculum |
|
|
|
self.sampler_manager = sampler_manager |
|
|
|
self.resampling_interval = resampling_interval |
|
|
|
self.ghost_controller = ghost_controller |
|
|
|
|
|
|
|
self.trainer_threads: List[threading.Thread] = [] |
|
|
|
self.kill_trainers = False |
|
|
|
|
|
|
and (self.resampling_interval) |
|
|
|
and (steps % self.resampling_interval == 0) |
|
|
|
) |
|
|
|
if meta_curriculum_reset or generalization_reset: |
|
|
|
if meta_curriculum_reset or generalization_reset or self.ghost_controller.reset: |
|
|
|
self.end_trainer_episodes(env, lessons_incremented) |
|
|
|
|
|
|
|
@timed |
|
|
|