
Move console logging to ConsoleWriter

/bug-failed-api-check
Ervin Teng, 5 years ago
Current commit
bcf073bf
3 files changed, 47 insertions(+), 28 deletions(-)
  1. ml-agents/mlagents/trainers/learn.py (3 changes)
  2. ml-agents/mlagents/trainers/stats.py (43 changes)
  3. ml-agents/mlagents/trainers/trainer/trainer.py (29 changes)

ml-agents/mlagents/trainers/learn.py (3 changes)


    CSVWriter,
    StatsReporter,
    GaugeWriter,
    ConsoleWriter,
)
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.sampler_class import SamplerManager

    )
    tb_writer = TensorboardWriter(summaries_dir)
    gauge_write = GaugeWriter()
    console_writer = ConsoleWriter()
    StatsReporter.add_writer(console_writer)
    if options.env_path is None:
        port = UnityEnvironment.DEFAULT_EDITOR_PORT
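
For orientation, here is a minimal sketch of how the newly registered ConsoleWriter is exercised. The StatsReporter(category) constructor and the per-key add_stat/write_stats flow are inferred from calls visible elsewhere in this diff; the category name is made up.

```python
from mlagents.trainers.stats import StatsReporter, ConsoleWriter

StatsReporter.add_writer(ConsoleWriter())  # shared, class-level writer registry

reporter = StatsReporter("3DBall")  # hypothetical stats category
reporter.add_stat("Environment/Cumulative Reward", 1.5)
reporter.add_stat("Is Training", 1.0)
reporter.write_stats(step=1000)  # fans the summaries out to every registered writer
```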

ml-agents/mlagents/trainers/stats.py (43 changes)


import abc
import csv
import os
import time
import logging

logger = logging.getLogger("mlagents.trainers")
class StatsSummary(NamedTuple):

    ) -> None:
        for val, stats_summary in values.items():
            set_gauge(f"{category}.{val}.mean", float(stats_summary.mean))

    def write_text(self, category: str, text: str, step: int) -> None:
        pass
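
GaugeWriter above and the new ConsoleWriter below implement the same StatsWriter interface whose two methods appear in this hunk. A hypothetical third writer would follow the same shape; this sketch assumes write_stats and write_text are the only methods a subclass must provide.

```python
from typing import Dict
from mlagents.trainers.stats import StatsSummary, StatsWriter


class MeanOnlyWriter(StatsWriter):  # hypothetical writer, not part of this commit
    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        for name, summary in values.items():
            print(f"[{category}] step {step}: {name} mean={summary.mean:.3f}")

    def write_text(self, category: str, text: str, step: int) -> None:
        pass  # this writer ignores free-form text
```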
class ConsoleWriter(StatsWriter):
    def __init__(self):
        # Wall-clock reference for the "Time Elapsed" field in log lines.
        self.training_start_time = time.time()

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        is_training = "Not Training."
        if "Is Training" in values:
            # "Is Training" is reported by the trainer as a 0/1 stat.
            stats_summary = values["Is Training"]
            if stats_summary.mean > 0.0:
                is_training = "Training."
        if "Environment/Cumulative Reward" in values:
            stats_summary = values["Environment/Cumulative Reward"]
            logger.info(
                "{}: Step: {}. "
                "Time Elapsed: {:0.3f} s "
                "Mean "
                "Reward: {:0.3f}"
                ". Std of Reward: {:0.3f}. {}".format(
                    category,
                    step,
                    time.time() - self.training_start_time,
                    stats_summary.mean,
                    stats_summary.std,
                    is_training,
                )
            )
        else:
            logger.info(
                "{}: Step: {}. No episode was completed since last summary. {}".format(
                    category, step, is_training
                )
            )

    def write_text(self, category: str, text: str, step: int) -> None:
        pass
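
To see what the class produces, it can be driven directly. The StatsSummary field names mean/std/num are taken from this diff (num appears in trainer.py below); constructing the tuple by keyword is an assumption about the NamedTuple's fields.

```python
import logging

logging.basicConfig(level=logging.INFO)  # make logger.info output visible

writer = ConsoleWriter()
writer.write_stats(
    "3DBall",  # hypothetical category name
    {
        "Is Training": StatsSummary(mean=1.0, std=0.0, num=1),
        "Environment/Cumulative Reward": StatsSummary(mean=1.5, std=0.4, num=12),
    },
    step=1000,
)
# Expected log line (elapsed time will vary):
# 3DBall: Step: 1000. Time Elapsed: 0.001 s Mean Reward: 1.500. Std of Reward: 0.400. Training.
```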

ml-agents/mlagents/trainers/trainer/trainer.py (29 changes)


from collections import deque
from mlagents_envs.timers import set_gauge
from mlagents.model_serialization import export_policy_model, SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.stats import StatsReporter

"""
Saves training statistics to Tensorboard.
"""
is_training = "Training." if self.should_still_train else "Not Training."
stats_summary = self.stats_reporter.get_stats_summaries(
"Environment/Cumulative Reward"
)
if stats_summary.num > 0:
logger.info(
"{}: {}: Step: {}. "
"Time Elapsed: {:0.3f} s "
"Mean "
"Reward: {:0.3f}"
". Std of Reward: {:0.3f}. {}".format(
self.run_id,
self.brain_name,
step,
time.time() - self.training_start_time,
stats_summary.mean,
stats_summary.std,
is_training,
)
)
set_gauge(f"{self.brain_name}.mean_reward", stats_summary.mean)
else:
logger.info(
" {}: {}: Step: {}. No episode was completed since last summary. {}".format(
self.run_id, self.brain_name, step, is_training
)
)
self.stats_reporter.add_stat("Is Training", float(self.should_still_train))
self.stats_reporter.write_stats(int(step))
@abc.abstractmethod
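
After this change the trainer no longer formats console output itself: GaugeWriter owns the set_gauge call and ConsoleWriter owns the log line, so the trainer's summary path reduces to roughly the sketch below. The enclosing method name and signature are assumptions; the two stats_reporter calls are taken from the diff.

```python
def _write_summary(self, step: int) -> None:  # hypothetical enclosing method
    """Record summary stats; all output is handled by the registered writers."""
    self.stats_reporter.add_stat("Is Training", float(self.should_still_train))
    self.stats_reporter.write_stats(int(step))
```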
