GitHub
5 年前
当前提交
45010af3
共有 18 个文件被更改,包括 306 次插入 和 104 次删除
-
40ml-agents/mlagents/trainers/agent_processor.py
-
13ml-agents/mlagents/trainers/curriculum.py
-
6ml-agents/mlagents/trainers/learn.py
-
2ml-agents/mlagents/trainers/ppo/policy.py
-
8ml-agents/mlagents/trainers/ppo/trainer.py
-
10ml-agents/mlagents/trainers/rl_trainer.py
-
2ml-agents/mlagents/trainers/sac/policy.py
-
10ml-agents/mlagents/trainers/sac/trainer.py
-
9ml-agents/mlagents/trainers/tests/test_agent_processor.py
-
2ml-agents/mlagents/trainers/tests/test_ppo.py
-
2ml-agents/mlagents/trainers/tests/test_sac.py
-
3ml-agents/mlagents/trainers/tests/test_simple_rl.py
-
4ml-agents/mlagents/trainers/tests/test_trainer_util.py
-
85ml-agents/mlagents/trainers/trainer.py
-
15ml-agents/mlagents/trainers/trainer_controller.py
-
4ml-agents/mlagents/trainers/trainer_util.py
-
119ml-agents/mlagents/trainers/stats.py
-
76ml-agents/mlagents/trainers/tests/test_stats.py
|
|||
from collections import defaultdict |
|||
from typing import List, Dict, NamedTuple |
|||
import numpy as np |
|||
import abc |
|||
import os |
|||
|
|||
from mlagents.tf_utils import tf |
|||
|
|||
|
|||
class StatsWriter(abc.ABC): |
|||
""" |
|||
A StatsWriter abstract class. A StatsWriter takes in a category, key, scalar value, and step |
|||
and writes it out by some method. |
|||
""" |
|||
|
|||
@abc.abstractmethod |
|||
def write_stats(self, category: str, key: str, value: float, step: int) -> None: |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def write_text(self, category: str, text: str, step: int) -> None: |
|||
pass |
|||
|
|||
|
|||
class TensorboardWriter(StatsWriter): |
|||
def __init__(self, base_dir: str): |
|||
self.summary_writers: Dict[str, tf.summary.FileWriter] = {} |
|||
self.base_dir: str = base_dir |
|||
|
|||
def write_stats(self, category: str, key: str, value: float, step: int) -> None: |
|||
self._maybe_create_summary_writer(category) |
|||
summary = tf.Summary() |
|||
summary.value.add(tag="{}".format(key), simple_value=value) |
|||
self.summary_writers[category].add_summary(summary, step) |
|||
self.summary_writers[category].flush() |
|||
|
|||
def _maybe_create_summary_writer(self, category: str) -> None: |
|||
if category not in self.summary_writers: |
|||
filewriter_dir = "{basedir}/{category}".format( |
|||
basedir=self.base_dir, category=category |
|||
) |
|||
if not os.path.exists(filewriter_dir): |
|||
os.makedirs(filewriter_dir) |
|||
self.summary_writers[category] = tf.summary.FileWriter(filewriter_dir) |
|||
|
|||
def write_text(self, category: str, text: str, step: int) -> None: |
|||
self._maybe_create_summary_writer(category) |
|||
self.summary_writers[category].add_summary(text, step) |
|||
|
|||
|
|||
class StatsSummary(NamedTuple): |
|||
mean: float |
|||
std: float |
|||
num: int |
|||
|
|||
|
|||
class StatsReporter: |
|||
writers: List[StatsWriter] = [] |
|||
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list)) |
|||
|
|||
def __init__(self, category): |
|||
""" |
|||
Generic StatsReporter. A category is the broadest type of storage (would |
|||
correspond the run name and trainer name, e.g. 3DBalltest_3DBall. A key is the |
|||
type of stat it is (e.g. Environment/Reward). Finally the Value is the float value |
|||
attached to this stat. |
|||
""" |
|||
self.category: str = category |
|||
|
|||
@staticmethod |
|||
def add_writer(writer: StatsWriter) -> None: |
|||
StatsReporter.writers.append(writer) |
|||
|
|||
def add_stat(self, key: str, value: float) -> None: |
|||
""" |
|||
Add a float value stat to the StatsReporter. |
|||
:param category: The highest categorization of the statistic, e.g. behavior name. |
|||
:param key: The type of statistic, e.g. Environment/Reward. |
|||
:param value: the value of the statistic. |
|||
""" |
|||
StatsReporter.stats_dict[self.category][key].append(value) |
|||
|
|||
def write_stats(self, step: int) -> None: |
|||
""" |
|||
Write out all stored statistics that fall under the category specified. |
|||
The currently stored values will be averaged, written out as a single value, |
|||
and the buffer cleared. |
|||
:param category: The category which to write out the stats. |
|||
:param step: Training step which to write these stats as. |
|||
""" |
|||
for key in StatsReporter.stats_dict[self.category]: |
|||
if len(StatsReporter.stats_dict[self.category][key]) > 0: |
|||
stat_mean = float(np.mean(StatsReporter.stats_dict[self.category][key])) |
|||
for writer in StatsReporter.writers: |
|||
writer.write_stats(self.category, key, stat_mean, step) |
|||
del StatsReporter.stats_dict[self.category] |
|||
|
|||
def write_text(self, text: str, step: int) -> None: |
|||
""" |
|||
Write out some text. |
|||
:param category: The highest categorization of the statistic, e.g. behavior name. |
|||
:param text: The text to write out. |
|||
:param step: Training step which to write these stats as. |
|||
""" |
|||
for writer in StatsReporter.writers: |
|||
writer.write_text(self.category, text, step) |
|||
|
|||
def get_stats_summaries(self, key: str) -> StatsSummary: |
|||
""" |
|||
Get the mean, std, and count of a particular statistic, since last write. |
|||
:param category: The highest categorization of the statistic, e.g. behavior name. |
|||
:param key: The type of statistic, e.g. Environment/Reward. |
|||
:returns: A StatsSummary NamedTuple containing (mean, std, count). |
|||
""" |
|||
return StatsSummary( |
|||
mean=np.mean(StatsReporter.stats_dict[self.category][key]), |
|||
std=np.std(StatsReporter.stats_dict[self.category][key]), |
|||
num=len(StatsReporter.stats_dict[self.category][key]), |
|||
) |
|
|||
import unittest.mock as mock |
|||
import os |
|||
|
|||
from mlagents.trainers.stats import StatsReporter, TensorboardWriter |
|||
|
|||
|
|||
def test_stat_reporter_add_summary_write(): |
|||
# Test add_writer |
|||
StatsReporter.writers.clear() |
|||
mock_writer1 = mock.Mock() |
|||
mock_writer2 = mock.Mock() |
|||
StatsReporter.add_writer(mock_writer1) |
|||
StatsReporter.add_writer(mock_writer2) |
|||
assert len(StatsReporter.writers) == 2 |
|||
|
|||
# Test add_stats and summaries |
|||
statsreporter1 = StatsReporter("category1") |
|||
statsreporter2 = StatsReporter("category2") |
|||
for i in range(10): |
|||
statsreporter1.add_stat("key1", float(i)) |
|||
statsreporter2.add_stat("key2", float(i)) |
|||
|
|||
statssummary1 = statsreporter1.get_stats_summaries("key1") |
|||
statssummary2 = statsreporter2.get_stats_summaries("key2") |
|||
|
|||
assert statssummary1.num == 10 |
|||
assert statssummary2.num == 10 |
|||
assert statssummary1.mean == 4.5 |
|||
assert statssummary2.mean == 4.5 |
|||
assert round(statssummary1.std, 1) == 2.9 |
|||
assert round(statssummary2.std, 1) == 2.9 |
|||
|
|||
# Test write_stats |
|||
step = 10 |
|||
statsreporter1.write_stats(step) |
|||
mock_writer1.write_stats.assert_called_once_with("category1", "key1", 4.5, step) |
|||
mock_writer2.write_stats.assert_called_once_with("category1", "key1", 4.5, step) |
|||
|
|||
|
|||
def test_stat_reporter_text(): |
|||
# Test add_writer |
|||
mock_writer = mock.Mock() |
|||
StatsReporter.writers.clear() |
|||
StatsReporter.add_writer(mock_writer) |
|||
assert len(StatsReporter.writers) == 1 |
|||
|
|||
statsreporter1 = StatsReporter("category1") |
|||
|
|||
# Test write_text |
|||
step = 10 |
|||
statsreporter1.write_text("this is a text", step) |
|||
mock_writer.write_text.assert_called_once_with("category1", "this is a text", step) |
|||
|
|||
|
|||
@mock.patch("mlagents.tf_utils.tf.Summary") |
|||
@mock.patch("mlagents.tf_utils.tf.summary.FileWriter") |
|||
def test_tensorboard_writer(mock_filewriter, mock_summary): |
|||
# Test write_stats |
|||
base_dir = "base_dir" |
|||
category = "category1" |
|||
tb_writer = TensorboardWriter(base_dir) |
|||
tb_writer.write_stats("category1", "key1", 1.0, 10) |
|||
|
|||
# Test that the filewriter has been created and the directory has been created. |
|||
filewriter_dir = "{basedir}/{category}".format(basedir=base_dir, category=category) |
|||
assert os.path.exists(filewriter_dir) |
|||
mock_filewriter.assert_called_once_with(filewriter_dir) |
|||
|
|||
# Test that the filewriter was written to and the summary was added. |
|||
mock_summary.return_value.value.add.assert_called_once_with( |
|||
tag="key1", simple_value=1.0 |
|||
) |
|||
mock_filewriter.return_value.add_summary.assert_called_once_with( |
|||
mock_summary.return_value, 10 |
|||
) |
|||
mock_filewriter.return_value.flush.assert_called_once() |
撰写
预览
正在加载...
取消
保存
Reference in new issue