浏览代码

Merge pull request #3631 from Unity-Technologies/release-0.15.0-fix-stats

make sure top-level timer is closed before writing
/release-0.15.0
GitHub 5 年前
当前提交
188d8589
共有 2 个文件被更改,包括 14 次插入14 次删除
  1. 14
      ml-agents/mlagents/trainers/learn.py
  2. 14
      ml-agents/mlagents/trainers/trainer_controller.py

14
ml-agents/mlagents/trainers/learn.py


from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.timers import hierarchical_timer
from mlagents_envs.timers import hierarchical_timer, get_timer_tree
from mlagents.logging_util import create_logger

tc.start_learning(env_manager)
finally:
env_manager.close()
write_timing_tree(summaries_dir, options.run_id)
def write_timing_tree(summaries_dir: str, run_id: str) -> None:
timing_path = f"{summaries_dir}/{run_id}_timers.json"
try:
with open(timing_path, "w") as f:
json.dump(get_timer_tree(), f, indent=4)
except FileNotFoundError:
logging.warning(
f"Unable to save to {timing_path}. Make sure the directory exists"
)
def create_sampler_manager(sampler_config, run_seed=None):

14
ml-agents/mlagents/trainers/trainer_controller.py


import os
import sys
import json
import logging
from typing import Dict, Optional, Set
from collections import defaultdict

UnityCommunicationException,
)
from mlagents.trainers.sampler_class import SamplerManager
from mlagents_envs.timers import hierarchical_timer, get_timer_tree, timed
from mlagents_envs.timers import hierarchical_timer, timed
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.trainers.trainer_util import TrainerFactory

"Learning was interrupted. Please wait while the graph is generated."
)
self._save_model()
def _write_timing_tree(self) -> None:
timing_path = f"{self.summaries_dir}/{self.run_id}_timers.json"
try:
with open(timing_path, "w") as f:
json.dump(get_timer_tree(), f, indent=4)
except FileNotFoundError:
self.logger.warning(
f"Unable to save to {timing_path}. Make sure the directory exists"
)
def _export_graph(self):
"""

pass
if self.train_model:
self._export_graph()
self._write_timing_tree()
def end_trainer_episodes(
self, env: EnvManager, lessons_incremented: Dict[str, bool]

正在加载...
取消
保存