浏览代码

Extend StatsWriter to allow handling of individual stat updates (#5249)

* Extend StatsWriter to allow callback handling of individual stat updates

* Update documentation and expand test coverage.
/check-for-ModelOverriders
GitHub 4 年前
当前提交
fd79d92c
共有 3 个文件被更改,包括 35 次插入1 次删除
  1. 3
      docs/Training-Plugins.md
  2. 24
      ml-agents/mlagents/trainers/stats.py
  3. 9
      ml-agents/mlagents/trainers/tests/test_stats.py

3
docs/Training-Plugins.md


#### Interface
The `StatsWriter.write_stats()` method must be implemented in any derived classes. It takes a "category" parameter,
which typically is the behavior name of the Agents being trained, and a dictionary of `StatSummary` values with
string keys.
string keys. Additionally, `StatsWriter.on_add_stat()` may be extended to register a callback handler for each stat
emission.
#### Registration
The `StatsWriter` registration function takes a `RunOptions` argument and returns a list of `StatsWriter`s. An

24
ml-agents/mlagents/trainers/stats.py


and writes it out by some method.
"""
def on_add_stat(
self,
category: str,
key: str,
value: float,
aggregation: StatsAggregationMethod = StatsAggregationMethod.AVERAGE,
) -> None:
"""
Callback method for handling an individual stat value as reported to the StatsReporter add_stat
or set_stat methods.
:param category: Category of the statistics. Usually this is the behavior name.
:param key: The type of statistic, e.g. Environment/Reward.
:param value: The value of the statistic.
:param aggregation: The aggregation method for the statistic, default StatsAggregationMethod.AVERAGE.
"""
pass
@abc.abstractmethod
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int

with StatsReporter.lock:
StatsReporter.stats_dict[self.category][key].append(value)
StatsReporter.stats_aggregation[self.category][key] = aggregation
for writer in StatsReporter.writers:
writer.on_add_stat(self.category, key, value, aggregation)
def set_stat(self, key: str, value: float) -> None:
"""

StatsReporter.stats_aggregation[self.category][
key
] = StatsAggregationMethod.MOST_RECENT
for writer in StatsReporter.writers:
writer.on_add_stat(
self.category, key, value, StatsAggregationMethod.MOST_RECENT
)
def write_stats(self, step: int) -> None:
"""

9
ml-agents/mlagents/trainers/tests/test_stats.py


statsreporter1.add_stat("key1", float(i))
statsreporter2.add_stat("key2", float(i))
statsreportercalls = [
mock.call(f"category{j}", f"key{j}", float(i), StatsAggregationMethod.AVERAGE)
for i in range(10)
for j in [1, 2]
]
mock_writer1.on_add_stat.assert_has_calls(statsreportercalls)
mock_writer2.on_add_stat.assert_has_calls(statsreportercalls)
statssummary1 = statsreporter1.get_stats_summaries("key1")
statssummary2 = statsreporter2.get_stats_summaries("key2")

正在加载...
取消
保存