|
|
|
|
|
|
import unittest |
|
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
from mlagents.trainers.stats import ( |
|
|
|
StatsReporter, |
|
|
|
TensorboardWriter, |
|
|
|
|
|
|
StatsPropertyType, |
|
|
|
StatsAggregationMethod, |
|
|
|
) |
|
|
|
|
|
|
|
from mlagents.trainers.env_manager import AgentManager |
|
|
|
|
|
|
|
|
|
|
|
def test_stat_reporter_add_summary_write(): |
|
|
|
|
|
|
"category1", StatsPropertyType.HYPERPARAMETERS, {"example": 1.0} |
|
|
|
) |
|
|
|
assert mock_summary.return_value.add_text.call_count >= 1 |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("aggregation_type", list(StatsAggregationMethod)) |
|
|
|
def test_agent_manager_stats_report(aggregation_type): |
|
|
|
stats_reporter = StatsReporter("recorder_name") |
|
|
|
manager = AgentManager(None, "behaviorName", stats_reporter) |
|
|
|
|
|
|
|
values = range(5) |
|
|
|
|
|
|
|
env_stats = {"stat": [(i, aggregation_type) for i in values]} |
|
|
|
manager.record_environment_stats(env_stats, 0) |
|
|
|
summary = stats_reporter.get_stats_summaries("stat") |
|
|
|
aggregation_result = { |
|
|
|
StatsAggregationMethod.AVERAGE: sum(values) / len(values), |
|
|
|
StatsAggregationMethod.MOST_RECENT: values[-1], |
|
|
|
StatsAggregationMethod.SUM: sum(values), |
|
|
|
StatsAggregationMethod.HISTOGRAM: sum(values) / len(values), |
|
|
|
} |
|
|
|
|
|
|
|
assert summary.aggregated_value == aggregation_result[aggregation_type] |
|
|
|
stats_reporter.write_stats(0) |
|
|
|
|
|
|
|
|
|
|
|
def test_tensorboard_writer_clear(tmp_path): |
|
|
|