您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
96 行
3.1 KiB
96 行
3.1 KiB
from unittest import mock
|
|
|
|
from mlagents.envs import timers
|
|
|
|
|
|
@timers.timed
def decorated_func(x: int = 0, y: float = 1.0) -> str:
    """Return a string of the form ``"<x> + <y> = <x + y>"``.

    The function is wrapped with ``timers.timed``; the test in this file
    expects each call to be recorded under a "decorated_func" timer node.
    """
    total = x + y
    return f"{x} + {y} = {total}"
|
|
|
|
|
|
def test_timers() -> None:
    """Run nested timers and check the structure of the recorded timer tree.

    The module-level timer stack is replaced (via ``mock.patch``) with a new
    ``timers.TimerStack`` so that only the timers started here are recorded.
    Durations are not checked: "total" and "self" are matched with
    ``mock.ANY``; only names, call counts and nesting are asserted.
    """

    def expected_node(name: str, count: int, children=None) -> dict:
        # Build one expected tree entry; leaf entries carry no "children" key.
        node = {"name": name, "total": mock.ANY, "count": count, "self": mock.ANY}
        if children is not None:
            node["children"] = children
        return node

    with mock.patch(
        "mlagents.envs.timers._global_timer_stack", new_callable=timers.TimerStack
    ) as timer_stack:
        # Exercise the timers: a repeated scope, a scope that raises, and a
        # scope entered after the exception has been handled.
        with timers.hierarchical_timer("top_level"):
            for _ in range(3):
                with timers.hierarchical_timer("multiple"):
                    decorated_func()

            exception_seen = False
            try:
                with timers.hierarchical_timer("raises"):
                    raise RuntimeError("timeout!")
            except RuntimeError:
                exception_seen = True

            with timers.hierarchical_timer("post_raise"):
                assert exception_seen

        # Expected hierarchy:
        #   (root)
        #       top_level
        #           multiple
        #               decorated_func
        #           raises
        #           post_raise
        root_node = timer_stack.root
        assert root_node.children.keys() == {"top_level"}

        top_node = root_node.children["top_level"]
        assert top_node.children.keys() == {"multiple", "raises", "post_raise"}

        # The "raises" scope must have been closed exactly once even though
        # its body raised.
        raising_node = top_node.children["raises"]
        assert raising_node.count == 1

        repeated_node = top_node.children["multiple"]
        assert repeated_node.count == 3

        timer_tree = timer_stack.get_timing_tree()

        expected_tree = expected_node(
            "root",
            1,
            children=[
                expected_node(
                    "top_level",
                    1,
                    children=[
                        expected_node(
                            "multiple",
                            3,
                            children=[expected_node("decorated_func", 3)],
                        ),
                        expected_node("raises", 1),
                        expected_node("post_raise", 1),
                    ],
                )
            ],
        )
        assert timer_tree == expected_tree
|