|
|
|
|
|
|
from mlagents.trainers.policy import Policy |
|
|
|
from mlagents.trainers.exception import UnityTrainerException |
|
|
|
from mlagents_envs.timers import hierarchical_timer |
|
|
|
from mlagents.trainers.progress_bar import ProgressBar |
|
|
|
|
|
|
|
logger = logging.getLogger("mlagents.trainers") |
|
|
|
|
|
|
|
|
|
|
self.training_start_time = time.time() |
|
|
|
self.summary_freq = self.trainer_parameters["summary_freq"] |
|
|
|
self.next_summary_step = self.summary_freq |
|
|
|
self.progress_bar = ProgressBar( |
|
|
|
self.brain_name, "Steps", count=self.get_step, total=self.get_max_steps |
|
|
|
) |
|
|
|
|
|
|
|
def _check_param_keys(self): |
|
|
|
for k in self.param_keys: |
|
|
|
|
|
|
self.run_id, self.brain_name, step, is_training |
|
|
|
) |
|
|
|
) |
|
|
|
self.progress_bar.update(step) |
|
|
|
self.stats_reporter.write_stats(int(step)) |
|
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|