|
|
|
|
|
|
from collections import defaultdict |
|
|
|
from typing import List, Dict, NamedTuple |
|
|
|
from typing import List, Dict, NamedTuple, Any |
|
|
|
import numpy as np |
|
|
|
import abc |
|
|
|
import csv |
|
|
|
|
|
|
|
|
|
|
from mlagents.tf_utils import tf |
|
|
|
from mlagents_envs.timers import set_gauge |
|
|
|
from mlagents.trainers.progress_bar import ProgressBar |
|
|
|
|
|
|
|
logger = logging.getLogger("mlagents.trainers") |
|
|
|
|
|
|
|
|
|
|
and writes it out by some method. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
self.properties: Dict[str, Dict[str, Any]] = {} |
|
|
|
|
|
|
|
def set_property(self, category: str, key: str, val: Any) -> None: |
|
|
|
""" |
|
|
|
Sets the max steps for a particular category. Used for tracking training progress. Optional to implement. |
|
|
|
""" |
|
|
|
self.properties[category] = {key: val} |
|
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|
def write_stats( |
|
|
|
self, category: str, values: Dict[str, StatsSummary], step: int |
|
|
|
|
|
|
|
|
|
|
class ConsoleWriter(StatsWriter): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.training_start_time = time.time() |
|
|
|
|
|
|
|
def write_stats( |
|
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class ProgressBarWriter(StatsWriter): |
|
|
|
def __init__(self): |
|
|
|
""" |
|
|
|
A StatsWriter that draws a progress bar at the bottom of the console. |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
self.progress_bars: Dict[str, ProgressBar] = {} |
|
|
|
|
|
|
|
def set_property(self, category: str, key: str, val: Any) -> None: |
|
|
|
""" |
|
|
|
Check to see if a max_steps was added. If so, create the progress bar. |
|
|
|
""" |
|
|
|
super().set_property(category, key, val) |
|
|
|
if key == "max_steps": |
|
|
|
self.progress_bars[category] = ProgressBar(category, "steps", 0, val) |
|
|
|
|
|
|
|
def write_stats( |
|
|
|
self, category: str, values: Dict[str, StatsSummary], step: int |
|
|
|
) -> None: |
|
|
|
if category in self.progress_bars: |
|
|
|
self.progress_bars[category].update(step) |
|
|
|
|
|
|
|
def write_text(self, category: str, text: str, step: int) -> None: |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class TensorboardWriter(StatsWriter): |
|
|
|
def __init__(self, base_dir: str): |
|
|
|
""" |
|
|
|
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
self.summary_writers: Dict[str, tf.summary.FileWriter] = {} |
|
|
|
self.base_dir: str = base_dir |
|
|
|
|
|
|
|
|
|
|
:param required_fields: If provided, the CSV writer won't write until these fields have statistics to write for |
|
|
|
them. |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
# We need to keep track of the fields in the CSV, as all rows need the same fields. |
|
|
|
self.csv_fields: Dict[str, List[str]] = {} |
|
|
|
self.required_fields = required_fields if required_fields else [] |
|
|
|
|
|
|
writers: List[StatsWriter] = [] |
|
|
|
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list)) |
|
|
|
|
|
|
|
def __init__(self, category): |
|
|
|
def __init__(self, category: str): |
|
|
|
""" |
|
|
|
Generic StatsReporter. A category is the broadest type of storage (would |
|
|
|
correspond the run name and trainer name, e.g. 3DBalltest_3DBall. A key is the |
|
|
|
|
|
|
self.category: str = category |
|
|
|
|
|
|
|
def set_property(self, key: str, value: Any) -> None: |
|
|
|
""" |
|
|
|
Add a generic property to all available writers. |
|
|
|
:param key: The type of property. |
|
|
|
:param value: The value of the property. |
|
|
|
""" |
|
|
|
for writer in StatsReporter.writers: |
|
|
|
writer.set_property(self.category, key, value) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def add_writer(writer: StatsWriter) -> None: |
|
|
|