GitHub
4 年前
当前提交
7954bd26
共有 20 个文件被更改,包括 234 次插入 和 29 次删除
-
1.github/workflows/pytest.yml
-
3com.unity.ml-agents/CHANGELOG.md
-
3ml-agents/mlagents/trainers/cli_utils.py
-
41ml-agents/mlagents/trainers/learn.py
-
18ml-agents/mlagents/trainers/settings.py
-
7ml-agents/mlagents/trainers/stats.py
-
8ml-agents/setup.py
-
58docs/Training-Plugins.md
-
3ml-agents-plugin-examples/README.md
-
0ml-agents-plugin-examples/mlagents_plugin_examples/__init__.py
-
27ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py
-
0ml-agents-plugin-examples/mlagents_plugin_examples/tests/__init__.py
-
13ml-agents-plugin-examples/mlagents_plugin_examples/tests/test_stats_writer_plugin.py
-
17ml-agents-plugin-examples/setup.py
-
1ml-agents/mlagents/plugins/__init__.py
-
63ml-agents/mlagents/plugins/stats_writer.py
|
|||
# Customizing Training via Plugins |
|||
|
|||
ML-Agents provides support for running your own python implementations of specific interfaces during the training |
|||
process. These interfaces are currently fairly limited, but will be expanded in the future. |
|||
|
|||
**Note:** Plugin interfaces should currently be considered "in beta", and they may change in future releases. |
|||
|
|||
## How to Write Your Own Plugin |
|||
[This video](https://www.youtube.com/watch?v=fY3Y_xPKWNA) explains the basics of how to create a plugin system using |
|||
setuptools; ML-Agents' plugin system is based on the same approach. |
|||
|
|||
The `ml-agents-plugin-examples` directory contains a reference implementation of each plugin interface, so it's a good |
|||
starting point. |
|||
|
|||
### setup.py |
|||
If you don't already have a `setup.py` file for your python code, you'll need to add one. `ml-agents-plugin-examples` |
|||
has a [minimal example](../ml-agents-plugin-examples/setup.py) of this. |
|||
|
|||
In the call to `setup()`, you'll need to add to the `entry_points` dictionary for each plugin interface that you |
|||
implement. The form of this is `{entry point name}={plugin module}:{plugin function}`. For example, in |
|||
`ml-agents-plugin-examples`: |
|||
```python |
|||
entry_points={ |
|||
ML_AGENTS_STATS_WRITER: [ |
|||
"example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer" |
|||
] |
|||
} |
|||
``` |
|||
* `ML_AGENTS_STATS_WRITER` (which is a string constant, `mlagents.stats_writer`) is the name of the plugin interface. |
|||
This must be one of the provided interfaces ([see below](#plugin-interfaces)). |
|||
* `example` is the plugin implementation name. This can be anything. |
|||
* `mlagents_plugin_examples.example_stats_writer` is the plugin module. This points to the module where the |
|||
plugin registration function is defined. |
|||
* `get_example_stats_writer` is the plugin registration function. This is called when running `mlagents-learn`. The |
|||
arguments and expected return type for this are different for each plugin interface. |
|||
|
|||
### Local Installation |
|||
Once you've defined `entry_points` in your `setup.py`, you will need to run |
|||
``` |
|||
pip install -e [path to your plugin code] |
|||
``` |
|||
in the same Python virtual environment in which `mlagents` is installed. |
|||
|
|||
## Plugin Interfaces |
|||
|
|||
### StatsWriter |
|||
The StatsWriter class receives various information from the training process, such as the average Agent reward in |
|||
each summary period. By default, we log this information to the console and write it to |
|||
[TensorBoard](Using-Tensorboard.md). |
|||
|
|||
#### Interface |
|||
The `StatsWriter.write_stats()` method must be implemented in any derived classes. It takes a "category" parameter, |
|||
which typically is the behavior name of the Agents being trained, a dictionary of `StatsSummary` values with |
|||
string keys, and an integer `step` parameter. |
|||
|
|||
#### Registration |
|||
The `StatsWriter` registration function takes a `RunOptions` argument and returns a list of `StatsWriter`s. An |
|||
example implementation is provided in [`mlagents_plugin_examples`](../ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py). |
|
|||
# ML-Agents Plugins |
|||
|
|||
See the [Plugins documentation](../docs/Training-Plugins.md) for more information. |
|
|||
from typing import Dict, List |
|||
from mlagents.trainers.settings import RunOptions |
|||
from mlagents.trainers.stats import StatsWriter, StatsSummary |
|||
|
|||
|
|||
class ExampleStatsWriter(StatsWriter):
    """
    Minimal StatsWriter plugin that serves as a reference implementation.

    It does nothing useful beyond echoing the stats it receives to stdout.
    """

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        # `step` is part of the signature but is not included in the output.
        message = f"ExampleStatsWriter category: {category} values: {values}"
        print(message)
|||
|
|||
|
|||
def get_example_stats_writer(run_options: RunOptions) -> List[StatsWriter]:
    """
    Plugin registration hook for the example StatsWriter.

    setup.py points at this function through its entry_points, and
    mlagents-learn calls it on startup to collect the StatsWriters
    that this plugin contributes.

    The return value has to be a list of StatsWriters.
    """
    print("Creating a new stats writer! This is so exciting!")
    example_writers: List[StatsWriter] = [ExampleStatsWriter()]
    return example_writers
|
|||
import pytest |
|||
|
|||
from mlagents.plugins.stats_writer import register_stats_writer_plugins |
|||
from mlagents.trainers.settings import RunOptions |
|||
|
|||
from mlagents_plugin_examples.example_stats_writer import ExampleStatsWriter |
|||
|
|||
|
|||
@pytest.mark.check_environment_trains
def test_register_stats_writers():
    # The example plugin must show up among the StatsWriters produced by
    # plugin registration.
    registered = register_stats_writer_plugins(RunOptions())
    example_writers = [
        writer for writer in registered if isinstance(writer, ExampleStatsWriter)
    ]
    assert example_writers
|
|||
from setuptools import setup |
|||
from mlagents.plugins import ML_AGENTS_STATS_WRITER |
|||
|
|||
# Registration functions that mlagents-learn will call, keyed by plugin interface.
#
# The single entry below registers get_example_stats_writer(), defined in
# mlagents_plugin_examples/example_stats_writer.py, with the
# ML_AGENTS_STATS_WRITER plugin interface.
# Each entry has the form "{entry point name}={plugin module}:{plugin function}".
_PLUGIN_ENTRY_POINTS = {
    ML_AGENTS_STATS_WRITER: [
        "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
    ]
}

setup(
    name="mlagents_plugin_examples",
    version="0.0.1",
    entry_points=_PLUGIN_ENTRY_POINTS,
)
|
|||
# Entry point group name under which StatsWriter plugins are registered
# (used as the key in a plugin's setup.py `entry_points` dictionary).
ML_AGENTS_STATS_WRITER = "mlagents.stats_writer"
|
|||
import sys |
|||
from typing import List |
|||
|
|||
# importlib.metadata is new in python3.8 |
|||
# We use the backport for older python versions. |
|||
if sys.version_info < (3, 8): |
|||
import importlib_metadata |
|||
else: |
|||
import importlib.metadata as importlib_metadata # pylint: disable=E0611 |
|||
|
|||
from mlagents.trainers.stats import StatsWriter |
|||
|
|||
from mlagents_envs import logging_util |
|||
from mlagents.plugins import ML_AGENTS_STATS_WRITER |
|||
from mlagents.trainers.settings import RunOptions |
|||
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter |
|||
|
|||
|
|||
# Module-level logger, named after this module.
logger = logging_util.get_logger(__name__)
|||
|
|||
|
|||
def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
    """
    Build the StatsWriters that mlagents-learn always uses:
    * TensorboardWriter - writes information to TensorBoard
    * GaugeWriter - records our internal stats
    * ConsoleWriter - outputs to stdout
    """
    settings = run_options.checkpoint_settings
    # clear_past_data is the inverse of the checkpoint settings' resume flag.
    tensorboard_writer = TensorboardWriter(
        settings.write_path, clear_past_data=not settings.resume
    )
    gauge_writer = GaugeWriter()
    console_writer = ConsoleWriter()
    return [tensorboard_writer, gauge_writer, console_writer]
|||
|
|||
|
|||
def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
    """
    Loads every StatsWriter plugin registered under the ML_AGENTS_STATS_WRITER
    entry point group, calls each plugin's registration function with
    run_options, and returns the combined list of StatsWriters they produce.

    A plugin whose registration raises an exception is logged and skipped, so
    that bad user code doesn't prevent the remaining plugins from loading.

    If no entry points are registered for the group at all, a warning is logged
    and the result of get_default_stats_writers(run_options) is returned instead
    of raising.
    """
    discovered = importlib_metadata.entry_points()
    # Newer importlib.metadata / importlib_metadata releases expose select() and
    # have deprecated or removed dict-style access; older releases return a
    # plain dict keyed by group name. Support both, and never raise KeyError
    # when the group is missing.
    if hasattr(discovered, "select"):
        entry_points = list(discovered.select(group=ML_AGENTS_STATS_WRITER))
    else:
        entry_points = list(discovered.get(ML_AGENTS_STATS_WRITER, []))

    if not entry_points:
        logger.warning(
            f"Unable to find any entry points for {ML_AGENTS_STATS_WRITER}. "
            "Using the default StatsWriters."
        )
        return get_default_stats_writers(run_options)

    all_stats_writers: List[StatsWriter] = []
    for entry_point in entry_points:
        try:
            logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
            plugin_func = entry_point.load()
            plugin_stats_writers = plugin_func(run_options)
            logger.debug(
                f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
            )
            all_stats_writers.extend(plugin_stats_writers)
        except Exception:
            # Catch errors from setting up the plugin, so that bad user code doesn't break things.
            # Exception (not BaseException) is used so that KeyboardInterrupt and
            # SystemExit still propagate.
            logger.exception(
                f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
            )
    return all_stats_writers
撰写
预览
正在加载...
取消
保存
Reference in new issue