def __init__(
    self,
    model_path: str,
    training_seed: int,
    sampler_manager: SamplerManager,
    resampling_interval: Optional[int],
    threaded: bool = True,
):
""" |
|
|
|
:param model_path: Path to save the model. |
|
|
|
|
|
|
:param training_seed: Seed to use for Numpy and Tensorflow random number generation. |
|
|
|
:param sampler_manager: SamplerManager object handles samplers for resampling the reset parameters. |
|
|
|
:param resampling_interval: Specifies number of simulation steps after which reset parameters are resampled. |
|
|
|
:param threaded: Whether or not to run trainers in a separate thread. Disable for testing/debugging. |
|
|
|
""" |
|
|
|
    self.trainers: Dict[str, Trainer] = {}
    self.brain_name_to_identifier: Dict[str, Set] = defaultdict(set)
    self.model_path = model_path
    self.sampler_manager = sampler_manager
    self.resampling_interval = resampling_interval
    self.threaded = threaded
    self.trainer_threads: List[threading.Thread] = []
    self.kill_trainers = False
    np.random.seed(training_seed)
    tf.set_random_seed(training_seed)
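
# A minimal usage sketch (illustrative only): the controller class name and the
# argument values below are assumptions, not taken from this file.
#
#   controller = TrainerController(
#       model_path="./models/run_1",
#       training_seed=1337,
#       sampler_manager=SamplerManager(...),  # depends on your sampler config
#       resampling_interval=20000,
#       threaded=True,
#   )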

def _create_trainer_and_manager(
    self, env_manager: EnvManager, name_behavior_id: str
) -> None:
    # trainer and agent_manager are constructed earlier in this method (omitted
    # in this excerpt). Wire them together: the trainer publishes updated
    # policies to the agent manager's policy queue and consumes its
    # trajectory queue.
    trainer.publish_policy_queue(agent_manager.policy_queue)
    trainer.subscribe_trajectory_queue(agent_manager.trajectory_queue)
    if self.threaded:
        # Start trainer thread
        trainerthread = threading.Thread(
            target=self.trainer_update_func, args=(trainer,), daemon=True
        )
        trainerthread.start()
        self.trainer_threads.append(trainerthread)
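
# Sketch of the per-trainer update loop each daemon thread runs. Only the name
# trainer_update_func is taken from this file; the body below is an assumption
# based on how the thread is started above and on the kill_trainers flag set
# in __init__.
def trainer_update_func(self, trainer: Trainer) -> None:
    # Keep advancing the trainer until the controller asks threads to stop.
    while not self.kill_trainers:
        with hierarchical_timer("trainer_advance"):
            trainer.advance()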

def _create_trainers_and_managers(
    self, env_manager: EnvManager, behavior_ids: Set[str]
) -> None:
    for behavior_id in behavior_ids:
        self._create_trainer_and_manager(env_manager, behavior_id)

def advance(self, env_manager: EnvManager) -> int:
    with hierarchical_timer("env_step"):
        num_steps = env_manager.advance()
    # Report the current curriculum lesson for each brain to its stats reporter.
    if self.meta_curriculum:
        for brain_name, curr in self.meta_curriculum.brains_to_curricula.items():
            if brain_name in self.trainers:
                self.trainers[brain_name].stats_reporter.set_stat(
                    "Environment/Lesson", curr.lesson_num
                )

    # When not threaded, advance the trainers inline after stepping the
    # environment; when threaded, the background trainer threads consume
    # trajectories on their own.
    if not self.threaded:
        with hierarchical_timer("trainer_advance"):
            for trainer in self.trainers.values():
                trainer.advance(empty_queue=True)

    return num_steps
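
# Hypothetical shutdown helper (not in this excerpt): something must set
# kill_trainers and join trainer_threads so the daemon threads exit cleanly.
# This sketch shows one way to do that.
def join_threads(self, timeout_seconds: float = 1.0) -> None:
    # Signal the update loops to stop, then give each thread a chance to exit.
    self.kill_trainers = True
    for trainer_thread in self.trainer_threads:
        trainer_thread.join(timeout_seconds)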