|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
import sys |
|
|
|
import json |
|
|
|
import logging |
|
|
|
from typing import Dict, Optional, Set |
|
|
|
from collections import defaultdict |
|
|
|
|
|
|
UnityCommunicationException, |
|
|
|
) |
|
|
|
from mlagents.trainers.sampler_class import SamplerManager |
|
|
|
from mlagents_envs.timers import hierarchical_timer, get_timer_tree, timed |
|
|
|
from mlagents_envs.timers import hierarchical_timer, timed |
|
|
|
from mlagents.trainers.trainer import Trainer |
|
|
|
from mlagents.trainers.meta_curriculum import MetaCurriculum |
|
|
|
from mlagents.trainers.trainer_util import TrainerFactory |
|
|
|
|
|
|
"Learning was interrupted. Please wait while the graph is generated." |
|
|
|
) |
|
|
|
self._save_model() |
|
|
|
|
|
|
|
def _write_timing_tree(self) -> None: |
|
|
|
timing_path = f"{self.summaries_dir}/{self.run_id}_timers.json" |
|
|
|
try: |
|
|
|
with open(timing_path, "w") as f: |
|
|
|
json.dump(get_timer_tree(), f, indent=4) |
|
|
|
except FileNotFoundError: |
|
|
|
self.logger.warning( |
|
|
|
f"Unable to save to {timing_path}. Make sure the directory exists" |
|
|
|
) |
|
|
|
|
|
|
|
def _export_graph(self): |
|
|
|
""" |
|
|
|
|
|
|
pass |
|
|
|
if self.train_model: |
|
|
|
self._export_graph() |
|
|
|
self._write_timing_tree() |
|
|
|
|
|
|
|
def end_trainer_episodes( |
|
|
|
self, env: EnvManager, lessons_incremented: Dict[str, bool] |
|
|
|