浏览代码

Merge branch 'develop-refactorprint' into develop-progress-bar

/develop/progress-bar
Ervin Teng 5 年前
当前提交
6b578de4
共有 4 个文件被更改，包括 79 次插入和 28 次删除
  1. 3
      ml-agents/mlagents/trainers/learn.py
  2. 43
      ml-agents/mlagents/trainers/stats.py
  3. 32
      ml-agents/mlagents/trainers/tests/test_stats.py
  4. 29
      ml-agents/mlagents/trainers/trainer/trainer.py

3
ml-agents/mlagents/trainers/learn.py


CSVWriter,
StatsReporter,
GaugeWriter,
ConsoleWriter,
)
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.sampler_class import SamplerManager

)
tb_writer = TensorboardWriter(summaries_dir)
gauge_write = GaugeWriter()
console_writer = ConsoleWriter()
StatsReporter.add_writer(console_writer)
if options.env_path is None:
port = UnityEnvironment.DEFAULT_EDITOR_PORT

43
ml-agents/mlagents/trainers/stats.py


import abc
import csv
import os
import time
import logging
logger = logging.getLogger("mlagents.trainers")
class StatsSummary(NamedTuple):

) -> None:
for val, stats_summary in values.items():
set_gauge(f"{category}.{val}.mean", float(stats_summary.mean))
def write_text(self, category: str, text: str, step: int) -> None:
    """Accept a text entry and do nothing with it (deliberate no-op)."""
class ConsoleWriter(StatsWriter):
    """Stats writer that emits a one-line summary per category via ``logger``."""

    def __init__(self):
        # Reference point for the "Time Elapsed" field of every summary line.
        self.training_start_time = time.time()

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        """Log a reward summary for ``category`` at ``step``.

        :param category: Label placed at the start of the log line.
        :param values: Mapping of stat name to its summary. Only the
            "Is Training" and "Environment/Cumulative Reward" entries are read.
        :param step: Step number reported in the log line.
        """
        # Reported as "Training." only when an "Is Training" stat is present
        # and its mean is positive.
        is_training = "Not Training."
        if "Is Training" in values:
            training_summary = values["Is Training"]
            if training_summary.mean > 0.0:
                is_training = "Training."

        if "Environment/Cumulative Reward" in values:
            reward_summary = values["Environment/Cumulative Reward"]
            logger.info(
                "{}: Step: {}. "
                "Time Elapsed: {:0.3f} s "
                "Mean "
                "Reward: {:0.3f}"
                ". Std of Reward: {:0.3f}. {}".format(
                    category,
                    step,
                    time.time() - self.training_start_time,
                    reward_summary.mean,
                    reward_summary.std,
                    is_training,
                )
            )
        else:
            # No reward stat was supplied for this category in this write.
            logger.info(
                "{}: Step: {}. No episode was completed since last summary. {}".format(
                    category, step, is_training
                )
            )

    def write_text(self, category: str, text: str, step: int) -> None:
        """Ignore text entries; this writer only reports stats."""
        pass

32
ml-agents/mlagents/trainers/tests/test_stats.py


import os
import pytest
import tempfile
import unittest
import csv
from mlagents.trainers.stats import (

StatsSummary,
GaugeWriter,
ConsoleWriter,
)

GaugeWriter.sanitize_string("Very/Very/Very Nested Stat")
== "Very.Very.VeryNestedStat"
)
class ConsoleWriterTest(unittest.TestCase):
    """Checks the log lines produced by ConsoleWriter.write_stats."""

    def test_console_writer(self):
        """Write two summaries and check the training marker on each line."""
        reward_key = "Environment/Cumulative Reward"
        training_key = "Is Training"
        with self.assertLogs("mlagents.trainers", level="INFO") as cm:
            writer = ConsoleWriter()
            label = "category1"
            positive_summary = StatsSummary(mean=1.0, std=1.0, num=1)
            zero_summary = StatsSummary(mean=0.0, std=0.0, num=1)

            # First write: positive "Is Training" mean.
            still_training = {
                reward_key: positive_summary,
                training_key: positive_summary,
            }
            writer.write_stats(label, still_training, 10)

            # Second write: zero "Is Training" mean.
            done_training = {
                reward_key: positive_summary,
                training_key: zero_summary,
            }
            writer.write_stats(label, done_training, 10)

        first_line = cm.output[0]
        second_line = cm.output[1]
        self.assertIn(
            "Mean Reward: 1.000. Std of Reward: 1.000. Training.", first_line
        )
        self.assertIn("Not Training.", second_line)

29
ml-agents/mlagents/trainers/trainer/trainer.py


from collections import deque
from mlagents_envs.timers import set_gauge
from mlagents.model_serialization import export_policy_model, SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.stats import StatsReporter

"""
Saves training statistics to Tensorboard.
"""
is_training = "Training." if self.should_still_train else "Not Training."
stats_summary = self.stats_reporter.get_stats_summaries(
"Environment/Cumulative Reward"
)
if stats_summary.num > 0:
logger.info(
"{}: {}: Step: {}. "
"Time Elapsed: {:0.3f} s "
"Mean "
"Reward: {:0.3f}"
". Std of Reward: {:0.3f}. {}".format(
self.run_id,
self.brain_name,
step,
time.time() - self.training_start_time,
stats_summary.mean,
stats_summary.std,
is_training,
)
)
set_gauge(f"{self.brain_name}.mean_reward", stats_summary.mean)
else:
logger.info(
" {}: {}: Step: {}. No episode was completed since last summary. {}".format(
self.run_id, self.brain_name, step, is_training
)
)
self.stats_reporter.add_stat("Is Training", float(self.should_still_train))
self.progress_bar.update(step)
self.stats_reporter.write_stats(int(step))

正在加载...
取消
保存