|
|
|
|
|
|
from typing import List, Dict, NamedTuple, Any, Optional |
|
|
|
import numpy as np |
|
|
|
import abc |
|
|
|
import csv |
|
|
|
import os |
|
|
|
import time |
|
|
|
from threading import RLock |
|
|
|
|
|
|
""" |
|
|
|
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters, |
|
|
|
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible |
|
|
|
with all types of properties. For instance, a TB writer doesn't need a max step, nor should
we write hyperparameters to the CSV.
|
|
|
:param category: The category that the property belongs to. |
|
|
|
:param type: The type of property. |
|
|
|
:param value: The property itself. |
|
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
class CSVWriter(StatsWriter):
    """
    A StatsWriter that writes the mean of each statistic to a CSV file.

    One file is kept per category, at {base_dir}/{category}.csv. The columns of a
    file are fixed by the first write_stats call whose statistics contain every
    required field; all later rows are written against those same columns.
    """

    def __init__(self, base_dir: str, required_fields: Optional[List[str]] = None):
        """
        :param base_dir: The directory within which to place the CSV file, which will be
            {base_dir}/{category}.csv.
        :param required_fields: If provided, the CSV writer won't write until these fields
            have statistics to write for them.
        """
        # We need to keep track of the fields in the CSV, as all rows need the same fields.
        self.csv_fields: Dict[str, List[str]] = {}
        self.required_fields = required_fields if required_fields else []
        self.base_dir: str = base_dir

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        """
        Append one row of statistic means for the given step to the category's CSV file.

        Nothing is written until a valid CSV file exists for the category (see
        _maybe_create_csv_file).
        :param category: The category whose CSV file receives the row.
        :param values: Mapping of statistic name to its summary; only `.mean` is written.
        :param step: The step count, written as the first column of the row.
        """
        if not self._maybe_create_csv_file(category, list(values.keys())):
            return
        row = [str(step)]
        # Only record the stats that showed up in the first valid row; a stat that is
        # missing from this call is written as the literal string "None".
        for key in self.csv_fields[category]:
            _val = values.get(key)
            row.append(str(_val.mean) if _val is not None else "None")
        # newline="" is required by the csv module so that it controls line endings
        # itself (otherwise extra blank rows appear on platforms that translate "\n").
        with open(self._get_filepath(category), "a", newline="") as file:
            csv.writer(file).writerow(row)

    def _maybe_create_csv_file(self, category: str, keys: List[str]) -> bool:
        """
        If no CSV file exists and the keys have the required values,
        make the CSV file and write the title row.
        Returns True if there is now (or already is) a valid CSV file.
        :param category: The category whose CSV file is checked or created.
        :param keys: The statistic names available for the current write.
        """
        if category in self.csv_fields:
            return True
        os.makedirs(self.base_dir, exist_ok=True)
        # Only store if the row contains the required fields
        if not all(item in keys for item in self.required_fields):
            return False
        with open(self._get_filepath(category), "w", newline="") as file:
            csv.writer(file).writerow(["Steps", *keys])
        # Register the category only after the title row was written, so a failed
        # file creation is retried instead of leaving a header-less file behind.
        self.csv_fields[category] = list(keys)
        return True

    def _get_filepath(self, category: str) -> str:
        """
        Return the path of the CSV file for a category: {base_dir}/{category}.csv.
        :param category: The category whose file path is requested.
        """
        return os.path.join(self.base_dir, category + ".csv")
|
|
|
|
|
|
|
|
|
|
|
class StatsReporter: |
|
|
|
writers: List[StatsWriter] = [] |
|
|
|
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list)) |
|
|
|
|
|
|
""" |
|
|
|
Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters, |
|
|
|
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible |
|
|
|
with all types of properties. For instance, a TB writer doesn't need a max step, nor should
we write hyperparameters to the CSV.
|
|
|
:param key: The type of property. |
|
|
|
:param value: The property itself. |
|
|
|
""" |
|
|
|