|
|
|
|
|
|
import os |
|
|
|
import sys |
|
|
|
import logging |
|
|
|
from typing import Dict, Optional, Set |
|
|
|
import threading |
|
|
|
from typing import Dict, Optional, Set, List |
|
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
self.meta_curriculum = meta_curriculum |
|
|
|
self.sampler_manager = sampler_manager |
|
|
|
self.resampling_interval = resampling_interval |
|
|
|
|
|
|
|
self.trainer_threads: List[threading.Thread] = [] |
|
|
|
self.kill_trainers = False |
|
|
|
np.random.seed(training_seed) |
|
|
|
tf.set_random_seed(training_seed) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trainer.publish_policy_queue(agent_manager.policy_queue) |
|
|
|
trainer.subscribe_trajectory_queue(agent_manager.trajectory_queue) |
|
|
|
# Start trainer thread |
|
|
|
trainerthread = threading.Thread( |
|
|
|
target=self.trainer_update_func, args=(trainer,), daemon=True |
|
|
|
) |
|
|
|
trainerthread.start() |
|
|
|
self.trainer_threads.append(trainerthread) |
|
|
|
|
|
|
|
def _create_trainers_and_managers( |
|
|
|
self, env_manager: EnvManager, behavior_ids: Set[str] |
|
|
|
|
|
|
if global_step != 0 and self.train_model: |
|
|
|
self._save_model() |
|
|
|
except (KeyboardInterrupt, UnityCommunicationException): |
|
|
|
self.kill_trainers = True |
|
|
|
if self.train_model: |
|
|
|
self._save_model_when_interrupted() |
|
|
|
pass |
|
|
|
|
|
|
"Environment/Lesson", curr.lesson_num |
|
|
|
) |
|
|
|
|
|
|
|
# Advance trainers. This can be done in a separate loop in the future. |
|
|
|
with hierarchical_timer("trainer_advance"): |
|
|
|
for trainer in self.trainers.values(): |
|
|
|
trainer.advance() |
|
|
|
# # Advance trainers. This can be done in a separate loop in the future. |
|
|
|
# with hierarchical_timer("trainer_advance"): |
|
|
|
# for trainer in self.trainers.values(): |
|
|
|
# trainer.advance() |
|
|
|
|
|
|
|
def trainer_update_func(self, trainer: Trainer) -> None:
    """Keep advancing ``trainer`` until ``self.kill_trainers`` becomes truthy.

    Each call to ``trainer.advance()`` runs inside a
    ``hierarchical_timer("trainer_advance")`` context. The stop flag is
    re-read before every iteration, so the loop exits after the in-flight
    ``advance()`` call returns once the flag has been set.

    :param trainer: The trainer whose ``advance()`` method is called.
    """
    while True:
        # Check the stop flag first so no advance() runs after it is set.
        if self.kill_trainers:
            return
        with hierarchical_timer("trainer_advance"):
            trainer.advance()