import sys
from typing import List

# importlib.metadata is new in python3.8
# We use the backport for older python versions.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata  # pylint: disable=E0611

from mlagents.trainers.stats import StatsWriter

from mlagents_envs import logging_util
from mlagents.plugins import ML_AGENTS_STATS_WRITER
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter


logger = logging_util.get_logger(__name__)


def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
    """
    The StatsWriters that mlagents-learn always uses:
    * A TensorboardWriter to write information to TensorBoard
    * A GaugeWriter to record our internal stats
    * A ConsoleWriter to output to stdout.
    """
    checkpoint_settings = run_options.checkpoint_settings
    return [
        TensorboardWriter(
            checkpoint_settings.write_path,
            # Only keep previous TensorBoard data when resuming a run.
            clear_past_data=not checkpoint_settings.resume,
            hidden_keys=["Is Training", "Step"],
        ),
        GaugeWriter(),
        ConsoleWriter(),
    ]


def _get_entry_points(group: str) -> list:
    """
    Return the list of installed entry points registered under ``group``.

    Works across the different return types of ``entry_points()``:
    * Python 3.10+ (and recent ``importlib_metadata`` backports) return an
      object with a ``select`` method. On 3.12+ dict-style ``in``/``[]``
      access no longer looks up groups, and on 3.10/3.11 it is deprecated.
    * Python 3.8/3.9 (and old backports) return a plain dict mapping group
      names to sequences of entry points.

    An empty list is returned when nothing is registered under ``group``.
    """
    all_entry_points = importlib_metadata.entry_points()
    if hasattr(all_entry_points, "select"):
        return list(all_entry_points.select(group=group))
    return list(all_entry_points.get(group, []))


def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
    """
    Registers all StatsWriter plugins (including the default one), and evaluates them,
    and returns the list of all the StatsWriter implementations.

    Each entry point in the ``ML_AGENTS_STATS_WRITER`` group is loaded and
    called with ``run_options``. Its returned StatsWriters are appended to the
    result.

    If no entry points are registered at all, a warning is logged and
    ``get_default_stats_writers(run_options)`` is returned instead.

    A plugin that raises while loading or running is logged and skipped, so
    one broken plugin does not prevent the others from being used.

    :param run_options: The RunOptions passed to every plugin function.
    :return: The concatenated list of StatsWriters from all working plugins.
    """
    # Query the installed entry points once, in a version-independent way.
    entry_points = _get_entry_points(ML_AGENTS_STATS_WRITER)
    if not entry_points:
        logger.warning(
            f"Unable to find any entry points for {ML_AGENTS_STATS_WRITER}, even the default ones. "
            "Uninstalling and reinstalling ml-agents via pip should resolve. "
            "Using default plugins for now."
        )
        return get_default_stats_writers(run_options)

    all_stats_writers: List[StatsWriter] = []
    for entry_point in entry_points:
        try:
            logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
            plugin_func = entry_point.load()
            plugin_stats_writers = plugin_func(run_options)
            logger.debug(
                f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
            )
            all_stats_writers += plugin_stats_writers
        except Exception:
            # Catch all exceptions from setting up the plugin, so that bad user code doesn't
            # break things. Exception (not BaseException) lets KeyboardInterrupt/SystemExit
            # still propagate so the user can abort.
            logger.exception(
                f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
            )
    return all_stats_writers