|
|
|
|
|
|
import time |
|
|
|
from threading import RLock |
|
|
|
|
|
|
|
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod |
|
|
|
|
|
|
|
from mlagents_envs.logging_util import get_logger |
|
|
|
from mlagents_envs.timers import set_gauge |
|
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
|
""" |
|
|
|
Takes a parameter dictionary and converts it to a human-readable string. |
|
|
|
Recurses if there are multiple levels of dict. Used to print out hyperparameters. |
|
|
|
param: param_dict: A Dictionary of key, value parameters. |
|
|
|
return: A string version of this dictionary. |
|
|
|
|
|
|
|
:param param_dict: A Dictionary of key, value parameters. |
|
|
|
:return: A string version of this dictionary. |
|
|
|
""" |
|
|
|
if not isinstance(param_dict, dict): |
|
|
|
return str(param_dict) |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class StatsSummary(NamedTuple): |
|
|
|
class StatsSummary(NamedTuple): # pylint: disable=inherit-non-class |
|
|
|
sum: float |
|
|
|
aggregation_method: StatsAggregationMethod |
|
|
|
return StatsSummary(0.0, 0.0, 0) |
|
|
|
return StatsSummary(0.0, 0.0, 0, 0.0, StatsAggregationMethod.AVERAGE) |
|
|
|
|
|
|
|
@property |
|
|
|
def aggregated_value(self): |
|
|
|
if self.aggregation_method == StatsAggregationMethod.SUM: |
|
|
|
return self.sum |
|
|
|
else: |
|
|
|
return self.mean |
|
|
|
|
|
|
|
|
|
|
|
class StatsPropertyType(Enum): |
|
|
|
|
|
|
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters, |
|
|
|
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible |
|
|
|
with all types of properties. For instance, a TB writer doesn't need a max step. |
|
|
|
|
|
|
|
:param type: The type of property. |
|
|
|
:param property_type: The type of property. |
|
|
|
:param value: The property itself. |
|
|
|
""" |
|
|
|
pass |
|
|
|
|
|
|
GaugeWriter.sanitize_string(f"{category}.{val}.mean"), |
|
|
|
float(stats_summary.mean), |
|
|
|
) |
|
|
|
set_gauge( |
|
|
|
GaugeWriter.sanitize_string(f"{category}.{val}.sum"), |
|
|
|
float(stats_summary.sum), |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class ConsoleWriter(StatsWriter): |
|
|
|
|
|
|
is_training = "Not Training" |
|
|
|
if "Is Training" in values: |
|
|
|
stats_summary = values["Is Training"] |
|
|
|
if stats_summary.mean > 0.0: |
|
|
|
if stats_summary.aggregated_value > 0.0: |
|
|
|
is_training = "Training" |
|
|
|
|
|
|
|
elapsed_time = time.time() - self.training_start_time |
|
|
|
|
|
|
def __init__(self, base_dir: str, clear_past_data: bool = False): |
|
|
|
""" |
|
|
|
A StatsWriter that writes to a Tensorboard summary. |
|
|
|
|
|
|
|
category. |
|
|
|
category. |
|
|
|
""" |
|
|
|
self.summary_writers: Dict[str, SummaryWriter] = {} |
|
|
|
self.base_dir: str = base_dir |
|
|
|
|
|
|
) -> None: |
|
|
|
self._maybe_create_summary_writer(category) |
|
|
|
for key, value in values.items(): |
|
|
|
self.summary_writers[category].add_scalar(f"{key}", value.mean, step) |
|
|
|
self.summary_writers[category].add_scalar( |
|
|
|
f"{key}", value.aggregated_value, step |
|
|
|
) |
|
|
|
self.summary_writers[category].flush() |
|
|
|
|
|
|
|
def _maybe_create_summary_writer(self, category: str) -> None: |
|
|
|
|
|
|
writers: List[StatsWriter] = [] |
|
|
|
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list)) |
|
|
|
lock = RLock() |
|
|
|
stats_aggregation: Dict[str, Dict[str, StatsAggregationMethod]] = defaultdict( |
|
|
|
lambda: defaultdict(lambda: StatsAggregationMethod.AVERAGE) |
|
|
|
) |
|
|
|
|
|
|
|
def __init__(self, category: str): |
|
|
|
""" |
|
|
|
|
|
|
Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters, |
|
|
|
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible |
|
|
|
with all types of properties. For instance, a TB writer doesn't need a max step. |
|
|
|
:param key: The type of property. |
|
|
|
|
|
|
|
:param property_type: The type of property. |
|
|
|
:param value: The property itself. |
|
|
|
""" |
|
|
|
with StatsReporter.lock: |
|
|
|
|
|
|
def add_stat(self, key: str, value: float) -> None: |
|
|
|
def add_stat( |
|
|
|
self, |
|
|
|
key: str, |
|
|
|
value: float, |
|
|
|
aggregation: StatsAggregationMethod = StatsAggregationMethod.AVERAGE, |
|
|
|
) -> None: |
|
|
|
|
|
|
|
:param aggregation: the aggregation method for the statistic, default StatsAggregationMethod.AVERAGE. |
|
|
|
StatsReporter.stats_aggregation[self.category][key] = aggregation |
|
|
|
|
|
|
|
StatsReporter.stats_aggregation[self.category][ |
|
|
|
key |
|
|
|
] = StatsAggregationMethod.MOST_RECENT |
|
|
|
|
|
|
|
def write_stats(self, step: int) -> None: |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
:param step: Training step which to write these stats as. |
|
|
|
""" |
|
|
|
with StatsReporter.lock: |
|
|
|
|
|
|
|
|
|
|
def get_stats_summaries(self, key: str) -> StatsSummary: |
|
|
|
""" |
|
|
|
Get the mean, std, and count of a particular statistic, since last write. |
|
|
|
Get the mean, std, count, sum and aggregation method of a particular statistic, since last write. |
|
|
|
|
|
|
|
:returns: A StatsSummary NamedTuple containing (mean, std, count). |
|
|
|
:returns: A StatsSummary containing summary statistics. |
|
|
|
if len(StatsReporter.stats_dict[self.category][key]) > 0: |
|
|
|
return StatsSummary( |
|
|
|
mean=np.mean(StatsReporter.stats_dict[self.category][key]), |
|
|
|
std=np.std(StatsReporter.stats_dict[self.category][key]), |
|
|
|
num=len(StatsReporter.stats_dict[self.category][key]), |
|
|
|
) |
|
|
|
return StatsSummary.empty() |
|
|
|
stat_values = StatsReporter.stats_dict[self.category][key] |
|
|
|
if len(stat_values) == 0: |
|
|
|
return StatsSummary.empty() |
|
|
|
|
|
|
|
return StatsSummary( |
|
|
|
mean=np.mean(stat_values), |
|
|
|
std=np.std(stat_values), |
|
|
|
num=len(stat_values), |
|
|
|
sum=np.sum(stat_values), |
|
|
|
aggregation_method=StatsReporter.stats_aggregation[self.category][key], |
|
|
|
) |