|
|
|
|
|
|
# # Unity ML-Agents Toolkit |
|
|
|
import math |
|
|
|
from typing import Any, Callable, Dict, Generator, TypeVar |
|
|
|
from typing import Any, Callable, Dict, Generator, List, TypeVar |
|
|
|
|
|
|
|
""" |
|
|
|
Lightweight, hierarchical timers for profiling sections of code. |
|
|
|
|
|
|
child.merge(other_child_node, is_parallel=is_parallel) |
|
|
|
|
|
|
|
|
|
|
|
class GaugeNode: |
|
|
|
""" |
|
|
|
Tracks the most recent value of a metric. This is analogous to gauges in statsd. |
|
|
|
""" |
|
|
|
|
|
|
|
__slots__ = ["value", "min_value", "max_value", "count"] |
|
|
|
|
|
|
|
def __init__(self, value: float): |
|
|
|
self.value = value |
|
|
|
self.min_value = value |
|
|
|
self.max_value = value |
|
|
|
self.count = 1 |
|
|
|
|
|
|
|
def update(self, new_value: float): |
|
|
|
self.min_value = min(self.min_value, new_value) |
|
|
|
self.max_value = max(self.max_value, new_value) |
|
|
|
self.value = new_value |
|
|
|
self.count += 1 |
|
|
|
|
|
|
|
def as_dict(self) -> Dict[str, float]: |
|
|
|
return { |
|
|
|
"value": self.value, |
|
|
|
"min": self.min_value, |
|
|
|
"max": self.max_value, |
|
|
|
"count": self.count, |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
class TimerStack: |
|
|
|
""" |
|
|
|
Tracks all the time spent. Users shouldn't use this directly, they should use the contextmanager below to make |
|
|
|
|
|
|
__slots__ = ["root", "stack", "start_time"] |
|
|
|
__slots__ = ["root", "stack", "start_time", "gauges"] |
|
|
|
self.gauges: Dict[str, GaugeNode] = {} |
|
|
|
self.gauges: Dict[str, GaugeNode] = {} |
|
|
|
|
|
|
|
def push(self, name: str) -> TimerNode: |
|
|
|
""" |
|
|
|
|
|
|
node = self.get_root() |
|
|
|
res["name"] = "root" |
|
|
|
|
|
|
|
# Only output gauges at top level |
|
|
|
if self.gauges: |
|
|
|
res["gauges"] = self._get_gauges() |
|
|
|
|
|
|
|
res["total"] = node.total |
|
|
|
res["count"] = node.count |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return res |
|
|
|
|
|
|
|
def set_gauge(self, name: str, value: float) -> None: |
|
|
|
if math.isnan(value): |
|
|
|
return |
|
|
|
gauge_node = self.gauges.get(name) |
|
|
|
if gauge_node: |
|
|
|
gauge_node.update(value) |
|
|
|
else: |
|
|
|
self.gauges[name] = GaugeNode(value) |
|
|
|
|
|
|
|
def _get_gauges(self) -> List[Dict[str, Any]]: |
|
|
|
gauges = [] |
|
|
|
for gauge_name, gauge_node in self.gauges.items(): |
|
|
|
gauge_dict: Dict[str, Any] = {"name": gauge_name, **gauge_node.as_dict()} |
|
|
|
gauges.append(gauge_dict) |
|
|
|
return gauges |
|
|
|
|
|
|
|
|
|
|
|
# Global instance of a TimerStack. This is generally all that we need for profiling, but you can potentially |
|
|
|
# create multiple instances and pass them to the contextmanager |
|
|
|
|
|
|
return func(*args, **kwargs) |
|
|
|
|
|
|
|
return wrapped # type: ignore |
|
|
|
|
|
|
|
|
|
|
|
def set_gauge(name: str, value: float, timer_stack: TimerStack = None) -> None: |
|
|
|
""" |
|
|
|
Updates the value of the gauge (or creates it if it hasn't been set before). |
|
|
|
""" |
|
|
|
timer_stack = timer_stack or _global_timer_stack |
|
|
|
timer_stack.set_gauge(name, value) |
|
|
|
|
|
|
|
|
|
|
|
def get_timer_tree(timer_stack: TimerStack = None) -> Dict[str, Any]: |
|
|
|