import sys
from typing import List

# importlib.metadata is new in Python 3.8.
# We use the backport for older Python versions.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata  # pylint: disable=E0611

from mlagents.trainers.stats import StatsWriter

from mlagents_envs import logging_util
from mlagents.plugins import ML_AGENTS_STATS_WRITER
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter


logger = logging_util.get_logger(__name__)


def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
    """
    The StatsWriters that mlagents-learn always uses:
    * A TensorboardWriter to write information to TensorBoard
    * A GaugeWriter to record our internal stats
    * A ConsoleWriter to output to stdout.
    """
    checkpoint_settings = run_options.checkpoint_settings
    return [
        TensorboardWriter(
            checkpoint_settings.write_path,
            clear_past_data=not checkpoint_settings.resume,
            hidden_keys=["Is Training", "Step"],
        ),
        GaugeWriter(),
        ConsoleWriter(),
    ]
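

# Third-party packages can contribute their own writers through the same
# mechanism as the defaults above: expose a function that takes RunOptions and
# returns a List[StatsWriter], and register it under the ML_AGENTS_STATS_WRITER
# entry-point group so that register_stats_writer_plugins loads and calls it.
# A minimal sketch for a setuptools-based setup.py; the package, module, and
# function names are hypothetical, and importing the constant assumes ml-agents
# is installed when the plugin package is built:
#
#     from mlagents.plugins import ML_AGENTS_STATS_WRITER
#     from setuptools import setup
#
#     setup(
#         name="my-stats-writer-plugin",
#         packages=["my_plugin"],
#         entry_points={
#             ML_AGENTS_STATS_WRITER: [
#                 "my_writer=my_plugin.writers:get_my_stats_writers"
#             ]
#         },
#     )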


def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
    """
    Registers all StatsWriter plugins (including the default one), evaluates
    them, and returns the combined list of StatsWriter instances they provide.
    """
    all_stats_writers: List[StatsWriter] = []
    all_entry_points = importlib_metadata.entry_points()
    if hasattr(all_entry_points, "select"):
        # Python 3.10+ (and importlib_metadata 3.6+): entry points are selected
        # by group. On Python 3.12+ entry_points() is no longer a dict keyed by
        # group name, so membership tests and indexing by group do not work.
        entry_points = list(all_entry_points.select(group=ML_AGENTS_STATS_WRITER))
    else:
        # Python 3.8/3.9 (and older backports): a dict keyed by group name.
        entry_points = list(all_entry_points.get(ML_AGENTS_STATS_WRITER, []))

    if not entry_points:
        logger.warning(
            f"Unable to find any entry points for {ML_AGENTS_STATS_WRITER}, even the default ones. "
            "Uninstalling and reinstalling ml-agents via pip should resolve this. "
            "Using default plugins for now."
        )
        return get_default_stats_writers(run_options)

    for entry_point in entry_points:
        try:
            logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
            plugin_func = entry_point.load()
            plugin_stats_writers = plugin_func(run_options)
            logger.debug(
                f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
            )
            all_stats_writers += plugin_stats_writers
        except BaseException:
            # Catch all exceptions from setting up the plugin, so that bad user code doesn't break things.
            logger.exception(
                f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
            )
    return all_stats_writers
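

# Illustrative helper (a sketch, not part of ml-agents): when the warning about
# missing entry points appears, it can help to list what is actually registered
# for a group. The function name is hypothetical. It assumes Python 3.8+ (stdlib
# importlib.metadata) and handles both the dict returned on Python 3.8/3.9 and
# the selectable EntryPoints API of Python 3.10+.
# Example: _list_entry_points(ML_AGENTS_STATS_WRITER)
from typing import List, Tuple


def _list_entry_points(group: str) -> List[Tuple[str, str]]:
    """Return sorted (name, "module:attr") pairs registered under `group`."""
    # Imported locally so this sketch does not affect module import on Python 3.7.
    from importlib import metadata

    eps = metadata.entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+: select entry points by group.
        selected = eps.select(group=group)
    else:
        # Python 3.8/3.9: a dict mapping group name to entry points.
        selected = eps.get(group, [])
    return sorted((ep.name, ep.value) for ep in selected)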