您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
131 行
4.4 KiB
131 行
4.4 KiB
from unittest import mock
|
|
import os
|
|
import pytest
|
|
import tempfile
|
|
import csv
|
|
|
|
from mlagents.trainers.stats import (
|
|
StatsReporter,
|
|
TensorboardWriter,
|
|
CSVWriter,
|
|
StatsSummary,
|
|
)
|
|
|
|
|
|
def test_stat_reporter_add_summary_write():
    """Check StatsReporter stat aggregation and fan-out of ``write_stats``.

    Registers two mock writers. It then records the values 0..9 under one key
    in each of two categories and checks the resulting summaries (count, mean,
    std). Finally it verifies that ``write_stats`` on the first reporter sends
    only that category's summaries to every registered writer.
    """
    # StatsReporter.writers is accessed on the class, so it is shared state:
    # start from an empty registry and empty it again afterwards so the mock
    # writers registered here cannot leak into other tests.
    StatsReporter.writers.clear()
    try:
        # Test add_writer
        mock_writer1 = mock.Mock()
        mock_writer2 = mock.Mock()
        StatsReporter.add_writer(mock_writer1)
        StatsReporter.add_writer(mock_writer2)
        assert len(StatsReporter.writers) == 2

        # Test add_stats and summaries
        statsreporter1 = StatsReporter("category1")
        statsreporter2 = StatsReporter("category2")
        for i in range(10):
            statsreporter1.add_stat("key1", float(i))
            statsreporter2.add_stat("key2", float(i))

        statssummary1 = statsreporter1.get_stats_summaries("key1")
        statssummary2 = statsreporter2.get_stats_summaries("key2")

        assert statssummary1.num == 10
        assert statssummary2.num == 10
        assert statssummary1.mean == 4.5
        assert statssummary2.mean == 4.5
        assert statssummary1.std == pytest.approx(2.9, abs=0.1)
        assert statssummary2.std == pytest.approx(2.9, abs=0.1)

        # Test write_stats: every registered writer receives exactly one call
        # carrying only category1's summaries.
        step = 10
        statsreporter1.write_stats(step)
        for writer in (mock_writer1, mock_writer2):
            writer.write_stats.assert_called_once_with(
                "category1", {"key1": statssummary1}, step
            )
    finally:
        StatsReporter.writers.clear()
|
|
|
|
|
|
def test_stat_reporter_text():
    """Check that ``write_text`` forwards category, text and step to writers."""
    # StatsReporter.writers is accessed on the class, so it is shared state:
    # reset it before and after so the mock writer registered here cannot
    # leak into other tests.
    StatsReporter.writers.clear()
    try:
        # Test add_writer
        mock_writer = mock.Mock()
        StatsReporter.add_writer(mock_writer)
        assert len(StatsReporter.writers) == 1

        statsreporter1 = StatsReporter("category1")

        # Test write_text
        step = 10
        statsreporter1.write_text("this is a text", step)
        mock_writer.write_text.assert_called_once_with(
            "category1", "this is a text", step
        )
    finally:
        StatsReporter.writers.clear()
|
|
|
|
|
|
@mock.patch("mlagents.tf_utils.tf.Summary")
@mock.patch("mlagents.tf_utils.tf.summary.FileWriter")
def test_tensorboard_writer(mock_filewriter, mock_summary):
    """Check that TensorboardWriter logs a stat through a per-category FileWriter.

    ``mock.patch`` decorators are applied bottom-up, so the FileWriter mock
    arrives as the first argument and the Summary mock as the second.
    """
    category = "category1"
    step = 10
    summary = StatsSummary(mean=1.0, std=1.0, num=1)

    with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
        # Write a single stat for one category.
        writer = TensorboardWriter(base_dir)
        writer.write_stats(category, {"key1": summary}, step)

        # The per-category directory must exist on disk and be passed to the
        # (mocked) FileWriter exactly once. This is checked inside the `with`
        # because the temporary directory is deleted on exit.
        expected_dir = "{basedir}/{category}".format(
            basedir=base_dir, category=category
        )
        assert os.path.exists(expected_dir)
        mock_filewriter.assert_called_once_with(expected_dir)

        # One summary value tagged "key1" with value 1.0 is built, added to the
        # FileWriter at the given step, and the FileWriter is then flushed.
        built_summary = mock_summary.return_value
        file_writer = mock_filewriter.return_value
        built_summary.value.add.assert_called_once_with(
            tag="key1", simple_value=1.0
        )
        file_writer.add_summary.assert_called_once_with(built_summary, step)
        file_writer.flush.assert_called_once()
|
|
|
|
|
|
def test_csv_writer():
    """Check that CSVWriter creates its file only once all required fields appear.

    The first write lacks the required ``key2`` and must not create the file.
    The next two writes carry both required keys, so the file should then
    contain a header row followed by one data row per write.
    """
    category = "category1"
    with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
        csv_writer = CSVWriter(base_dir, required_fields=["key1", "key2"])
        statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1)
        csv_path = "{basedir}/{category}.csv".format(
            basedir=base_dir, category=category
        )

        # The required keys weren't in the stats, so nothing is written yet.
        csv_writer.write_stats(category, {"key1": statssummary1}, 10)
        assert not os.path.exists(csv_path)

        # The required keys are now in the stats, so the file gets created.
        csv_writer.write_stats(
            category, {"key1": statssummary1, "key2": statssummary1}, 10
        )
        csv_writer.write_stats(
            category, {"key1": statssummary1, "key2": statssummary1}, 20
        )
        assert os.path.exists(csv_path)

        # The csv docs say file objects handed to csv.reader should be opened
        # with newline="" so the reader handles line endings itself.
        with open(csv_path, newline="") as csv_file:
            rows = list(csv.reader(csv_file, delimiter=","))

        # A header row, then one three-column row per successful write.
        assert len(rows) == 3
        header, *data_rows = rows
        for field in ("Steps", "key1", "key2"):
            assert field in header
        for row in data_rows:
            assert len(row) == 3
|