浏览代码
python timers (#2180)
python timers (#2180)
* Timer proof-of-concept
* micro optimizations
* add some timers
* cleanup, add asserts
* Cleanup (no start/end methods) and handle exceptions
* unit test and decorator
* move output code, add a decorator
* cleanup
* module docstring
* actually write the timings when done with training
* use __qualname__ instead
* add a few more timers
* fix mock import
* fix unit test
* don't need fwd reference
* cleanup root
* always write timers, add comments
* undo accidental change
/develop-generalizationTraining-TrainerController
GitHub
5 年前
当前提交
84d9d622
共有 5 个文件被更改,包括 306 次插入 和 7 次删除
-
9ml-agents-envs/mlagents/envs/subprocess_env_manager.py
-
3ml-agents/mlagents/trainers/ppo/policy.py
-
24ml-agents/mlagents/trainers/trainer_controller.py
-
96ml-agents-envs/mlagents/envs/tests/test_timers.py
-
181ml-agents-envs/mlagents/envs/timers.py
|
|||
from unittest import mock |
|||
|
|||
from mlagents.envs import timers |
|||
|
|||
|
|||
@timers.timed
def decorated_func(x: int = 0, y: float = 1.0) -> str:
    """Return the sum of ``x`` and ``y`` rendered as ``"<x> + <y> = <sum>"``."""
    total = x + y
    return f"{x} + {y} = {total}"
|||
|
|||
|
|||
def test_timers() -> None:
    """
    Run nested, repeated and exception-raising timers against a fresh TimerStack and verify the recorded tree.
    """

    def expected_node(name, count, children=None):
        # Timings differ from run to run, so only names, counts and nesting are pinned down.
        node = {"name": name, "total": mock.ANY, "count": count, "self": mock.ANY}
        if children is not None:
            node["children"] = children
        return node

    # Swap the module-level stack for a newly constructed TimerStack while the test runs.
    with mock.patch(
        "mlagents.envs.timers._global_timer_stack", new_callable=timers.TimerStack
    ) as test_timer:
        with timers.hierarchical_timer("top_level"):
            # Entering the same name repeatedly should accumulate on a single node.
            for _ in range(3):
                with timers.hierarchical_timer("multiple"):
                    decorated_func()

            raised = False
            try:
                with timers.hierarchical_timer("raises"):
                    raise RuntimeError("timeout!")
            except RuntimeError:
                raised = True

            with timers.hierarchical_timer("post_raise"):
                assert raised

        # Expected hierarchy:
        #   (root)
        #     top_level
        #       multiple
        #         decorated_func
        #       raises
        #       post_raise
        root = test_timer.root
        assert set(root.children) == {"top_level"}

        top_level = root.children["top_level"]
        assert set(top_level.children) == {"multiple", "raises", "post_raise"}

        # The "raises" scope must have been closed (and counted) even though its body raised.
        assert top_level.children["raises"].count == 1
        assert top_level.children["multiple"].count == 3

        expected_tree = expected_node(
            "root",
            1,
            [
                expected_node(
                    "top_level",
                    1,
                    [
                        expected_node(
                            "multiple", 3, [expected_node("decorated_func", 3)]
                        ),
                        expected_node("raises", 1),
                        expected_node("post_raise", 1),
                    ],
                )
            ],
        )
        assert test_timer.get_timing_tree() == expected_tree
|
|||
# # Unity ML-Agents Toolkit |
|||
from time import perf_counter |
|||
|
|||
from contextlib import contextmanager |
|||
from typing import Any, Callable, Dict, Generator, TypeVar |
|||
|
|||
""" |
|||
Lightweight, hierarchical timers for profiling sections of code. |
|||
|
|||
Example: |
|||
|
|||
@timed |
|||
def foo(t): |
|||
time.sleep(t) |
|||
|
|||
def main(): |
|||
for i in range(3): |
|||
foo(i + 1) |
|||
with hierarchical_timer("context"): |
|||
foo(1) |
|||
|
|||
print(get_timer_tree()) |
|||
|
|||
This would produce a timer tree like |
|||
(root) |
|||
"foo" |
|||
"context" |
|||
"foo" |
|||
|
|||
The total time and counts are tracked for each block of code; in this example "foo" and "context.foo" are considered |
|||
distinct blocks, and are tracked separately. |
|||
|
|||
The decorator and contextmanager are equivalent; the context manager may be more useful if you want more control |
|||
over the timer name, or are splitting up multiple sections of a large function. |
|||
""" |
|||
|
|||
|
|||
class TimerNode:
    """
    One node of the timer hierarchy: the accumulated time and hit count for a single block of code.
    """

    __slots__ = ["children", "total", "count"]

    def __init__(self):
        # A node does not store its own name; the parent's ``children`` dictionary is keyed by it instead.
        self.total: float = 0.0
        self.count: int = 0
        self.children: Dict[str, TimerNode] = {}

    def get_child(self, name: str) -> "TimerNode":
        """
        Return the child registered under ``name``, creating and registering an empty one on first use.
        """
        try:
            return self.children[name]
        except KeyError:
            fresh = TimerNode()
            self.children[name] = fresh
            return fresh

    def add_time(self, elapsed: float) -> None:
        """
        Record one more pass through this node: add ``elapsed`` to the total and bump the count by one.
        """
        self.total = self.total + elapsed
        self.count = self.count + 1
|||
|
|||
|
|||
class TimerStack:
    """
    Tracks all the time spent. Users shouldn't use this directly, they should use the contextmanager below to make
    sure that pushes and pops are always matched.
    """

    __slots__ = ["root", "stack", "start_time"]

    def __init__(self):
        # The root node never has add_time() called on it; its elapsed time is derived from start_time instead.
        self.root = TimerNode()
        # The last element of the stack is the node for the innermost block currently being timed.
        self.stack = [self.root]
        self.start_time = perf_counter()

    def push(self, name: str) -> "TimerNode":
        """
        Called when entering a new block of code that is timed (e.g. with a contextmanager).

        Looks up (or creates) the child called ``name`` under the node on top of the stack, makes it the new top
        of the stack, and returns it.
        """
        current_node = self.stack[-1]
        next_node = current_node.get_child(name)
        self.stack.append(next_node)
        return next_node

    def pop(self) -> None:
        """
        Called when exiting a block of code that is timed (e.g. with a contextmanager).
        """
        self.stack.pop()

    def get_timing_tree(self, node: "TimerNode" = None) -> Dict[str, Any]:
        """
        Recursively build a tree of timings, suitable for output/archiving.

        :param node: Node to report on. When omitted, the root is used; the result then also carries
            ``"name": "root"``, its "total" is the time since this TimerStack was created, and its "count" is 1.
        :return: Dictionary with "total", "count" and "self" entries, plus a "children" list (one named dictionary
            per child, in insertion order) when the node has any children.
        """
        res: Dict[str, Any]
        if node is None:
            # Special case the root - total is time since it was created, and count is 1
            node = self.root
            res = {
                "name": "root",
                "total": perf_counter() - self.start_time,
                "count": 1,
            }
        else:
            res = {"total": node.total, "count": node.count}

        child_list = [
            {"name": child_name, **self.get_timing_tree(child_node)}
            for child_name, child_node in node.children.items()
        ]
        child_total = sum((child["total"] for child in child_list), 0.0)

        # "self" time is total time minus all time spent on children. Use the reported total rather than
        # node.total: the root's own total is never accumulated, so node.total would always yield 0.0 for it.
        res["self"] = max(0.0, res["total"] - child_total)
        if child_list:
            res["children"] = child_list

        return res
|||
|
|||
|
|||
# Global instance of a TimerStack, used whenever no explicit stack is passed in. This is generally all that we need
# for profiling, but you can potentially create multiple instances and pass them to the contextmanager.
|||
_global_timer_stack = TimerStack() |
|||
|
|||
|
|||
@contextmanager
def hierarchical_timer(name: str, timer_stack: TimerStack = None) -> Generator:
    """
    Time the enclosed block of code under ``name``, nested beneath whichever timer is currently active.

    On entry the node for ``name`` is pushed onto ``timer_stack`` (the global stack when none is given). On exit,
    whether normal or through an exception, the elapsed time is added to that node and it is popped again.
    """
    stack = timer_stack if timer_stack else _global_timer_stack
    node = stack.push(name)
    began = perf_counter()

    try:
        # Control passes to the body of the ``with`` statement here.
        yield
    finally:
        # Runs on both normal exit and exceptions; an exception keeps propagating once the time is recorded.
        node.add_time(perf_counter() - began)
        stack.pop()
|||
|
|||
|
|||
# This is used to ensure the signature of the decorated function is preserved
# See also https://github.com/python/mypy/issues/3157
FuncT = TypeVar("FuncT", bound=Callable[..., Any])


def timed(func: FuncT) -> FuncT:
    """
    Decorator for timing a function or method. The name of the timer will be the qualified name of the function.
    Usage:
        @timed
        def my_func(x, y):
            return x + y
    Note that because this doesn't take arguments, the global timer stack is always used.

    The returned wrapper carries the decorated function's metadata (``__name__``, ``__qualname__``, ``__doc__``,
    ...) and exposes the original callable as ``__wrapped__``.
    """
    # Imported here so the decorator does not rely on a change to the module's import header.
    from functools import wraps

    @wraps(func)
    def wrapped(*args, **kwargs):
        with hierarchical_timer(func.__qualname__):
            return func(*args, **kwargs)

    return wrapped  # type: ignore
|||
|
|||
|
|||
def get_timer_tree(timer_stack: TimerStack = None) -> Dict[str, Any]:
    """
    Build the dictionary tree of timings for ``timer_stack``, falling back to the global stack when none is given.
    """
    stack = timer_stack if timer_stack else _global_timer_stack
    tree = stack.get_timing_tree()
    return tree
撰写
预览
正在加载...
取消
保存
Reference in new issue