您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
236 行
7.4 KiB
236 行
7.4 KiB
# # Unity ML-Agents Toolkit
|
|
from time import perf_counter
|
|
|
|
from contextlib import contextmanager
|
|
from typing import Any, Callable, Dict, Generator, TypeVar
|
|
|
|
"""
Lightweight, hierarchical timers for profiling sections of code.

Example:

@timed
def foo(t):
    time.sleep(t)

def main():
    for i in range(3):
        foo(i + 1)
    with hierarchical_timer("context"):
        foo(1)

    print(get_timer_tree())

This would produce a timer tree like
    (root)
        "foo"
        "context"
            "foo"

The total time and counts are tracked for each block of code; in this example "foo" and "context.foo" are considered
distinct blocks, and are tracked separately.

The decorator and contextmanager are equivalent; the context manager may be more useful if you want more control
over the timer name, or are splitting up multiple sections of a large function.
"""
|
|
|
|
|
|
class TimerNode:
    """
    One node of the timer hierarchy: the accumulated time and call count of a single block of code,
    plus the nodes of any timed blocks nested inside it.
    """

    __slots__ = ["children", "total", "count", "is_parallel"]

    def __init__(self):
        # A node never stores its own name; the key under which the parent holds it in
        # ``children`` is the name.
        self.children: Dict[str, TimerNode] = {}
        self.total: float = 0.0
        self.count: int = 0
        self.is_parallel = False

    def get_child(self, name: str) -> "TimerNode":
        """
        Return the child node registered under ``name``, creating and registering it first if needed.
        """
        existing = self.children.get(name)
        if existing is not None:
            return existing
        created = TimerNode()
        self.children[name] = created
        return created

    def add_time(self, elapsed: float) -> None:
        """
        Record one more execution of this block that took ``elapsed`` seconds.
        """
        self.count += 1
        self.total += elapsed

    def merge(self, other: "TimerNode", root_name: str = None, is_parallel=True):
        """
        Fold another timer tree into this one, summing totals and counts node by node.

        :param other: The node (and, recursively, its children) to merge in.
        :param root_name: If given (and non-empty), ``other`` is merged into the child of this node with
            that name instead of into this node itself.
        :param is_parallel: Whether the merged block ran in parallel; OR-ed into every node touched.
        :return: None
        """
        target = self.get_child(root_name) if root_name else self

        target.total += other.total
        target.count += other.count
        target.is_parallel |= is_parallel

        # Children are always merged by their own names; root_name only applies to the top level.
        for name, other_child in other.children.items():
            target.get_child(name).merge(other_child, is_parallel=is_parallel)
|
|
|
|
|
|
class TimerStack:
    """
    Holds the tree of TimerNodes together with the stack of currently-open timed blocks.

    Not meant to be driven by hand: use the context manager defined in this module so that every
    push is paired with a pop.
    """

    __slots__ = ["root", "stack", "start_time"]

    def __init__(self):
        # A fresh stack is exactly a stack that has just been reset.
        self.reset()

    def reset(self):
        """
        Discard all recorded timings and restart the clock used for the root's total time.
        """
        fresh_root = TimerNode()
        self.root = fresh_root
        self.stack = [fresh_root]
        self.start_time = perf_counter()

    def push(self, name: str) -> TimerNode:
        """
        Enter a timed block: descend into (or create) the child ``name`` of the innermost open block
        and make it the new innermost block.
        """
        entered = self.stack[-1].get_child(name)
        self.stack.append(entered)
        return entered

    def pop(self) -> None:
        """
        Leave the innermost timed block.
        """
        self.stack.pop()

    def get_root(self) -> TimerNode:
        """
        Return the root node after refreshing it: its total becomes the time elapsed since the stack
        was created or last reset, and its count is set to 1.
        """
        top = self.root
        top.count = 1
        top.total = perf_counter() - self.start_time
        return top

    def get_timing_tree(self, node: TimerNode = None) -> Dict[str, Any]:
        """
        Convert the subtree under ``node`` (the whole tree when ``node`` is None) into nested
        dictionaries suitable for output/archiving.
        """
        tree: Dict[str, Any] = {}
        if node is None:
            # The root is special: its total is the time since creation/reset and its count is 1.
            node = self.get_root()
            tree["name"] = "root"

        tree["total"] = node.total
        tree["count"] = node.count

        if node.is_parallel:
            # Flag parallel blocks so that a timer smaller than the sum of its children is less surprising.
            tree["is_parallel"] = True

        children = []
        time_in_children = 0.0
        for label, subnode in node.children.items():
            subtree: Dict[str, Any] = {"name": label, **self.get_timing_tree(subnode)}
            time_in_children += subtree["total"]
            children.append(subtree)

        # "self" is the time not attributed to any child, clamped at zero.
        tree["self"] = max(0.0, node.total - time_in_children)
        if children:
            tree["children"] = children

        return tree
|
|
|
|
|
|
# Module-wide default TimerStack. This single instance is normally all that profiling needs, but
# additional TimerStack instances can be created and passed to the context manager instead.
_global_timer_stack = TimerStack()
|
|
|
|
|
|
@contextmanager
def hierarchical_timer(name: str, timer_stack: TimerStack = None) -> Generator:
    """
    Context manager that times the enclosed block of code under ``name``.

    The block's node is pushed onto ``timer_stack`` (the global stack when none is given) on entry
    and yielded to the caller. On exit, whether normal or via an exception, the elapsed time is
    added to that node and the node is popped again.
    """
    stack = timer_stack if timer_stack else _global_timer_stack
    node = stack.push(name)
    began = perf_counter()

    try:
        # Control returns to the body of the ``with`` statement here.
        yield node
    finally:
        # Runs on both normal exit and exceptions; an exception keeps propagating after the
        # time has been recorded.
        node.add_time(perf_counter() - began)
        stack.pop()
|
|
|
|
|
|
# FuncT lets type checkers see that the decorator returns a callable with the same signature as its input.
# See also https://github.com/python/mypy/issues/3157
FuncT = TypeVar("FuncT", bound=Callable[..., Any])


def timed(func: FuncT) -> FuncT:
    """
    Decorator for timing a function or method. The name of the timer will be the qualified name of the function.
    Usage:
        @timed
        def my_func(x, y):
            return x + y
    Note that because this doesn't take arguments, the global timer stack is always used.

    :param func: The function or method to time.
    :return: A wrapper that runs ``func`` inside ``hierarchical_timer(func.__qualname__)`` and returns
        its result. The wrapper carries over ``func``'s metadata (``__name__``, ``__doc__``, ...).
    """
    # Imported locally so this fix does not depend on the module-level import block.
    from functools import wraps

    # Without wraps(), every decorated function would report itself as "wrapped" and lose its
    # docstring, which hurts introspection, debugging and documentation tools.
    @wraps(func)
    def wrapped(*args, **kwargs):
        with hierarchical_timer(func.__qualname__):
            return func(*args, **kwargs)

    return wrapped  # type: ignore
|
|
|
|
|
|
def get_timer_tree(timer_stack: TimerStack = None) -> Dict[str, Any]:
    """
    Build the dictionary tree of timings for ``timer_stack``, falling back to the global
    TimerStack when none is provided.
    """
    return (timer_stack or _global_timer_stack).get_timing_tree()
|
|
|
|
|
|
def get_timer_root(timer_stack: TimerStack = None) -> TimerNode:
    """
    Return the root TimerNode of ``timer_stack``, falling back to the global TimerStack when
    none is provided.
    """
    return (timer_stack or _global_timer_stack).get_root()
|
|
|
|
|
|
def reset_timers(timer_stack: TimerStack = None) -> None:
    """
    Reset ``timer_stack``, falling back to the global TimerStack when none is provided.
    """
    (timer_stack or _global_timer_stack).reset()
|