浏览代码

python timers (#2180)

* Timer proof-of-concept

* micro optimizations

* add some timers

* cleanup, add asserts

* Cleanup (no start/end methods) and handle exceptions

* unit test and decorator

* move output code, add a decorator

* cleanup

* module docstring

* actually write the timings when done with training

* use __qualname__ instead

* add a few more timers

* fix mock import

* fix unit test

* don't need fwd reference

* cleanup root

* always write timers, add comments

* undo accidental change
/develop-generalizationTraining-TrainerController
GitHub 5 年前
当前提交
84d9d622
共有 5 个文件被更改,包括 306 次插入7 次删除
  1. 9
      ml-agents-envs/mlagents/envs/subprocess_env_manager.py
  2. 3
      ml-agents/mlagents/trainers/ppo/policy.py
  3. 24
      ml-agents/mlagents/trainers/trainer_controller.py
  4. 96
      ml-agents-envs/mlagents/envs/tests/test_timers.py
  5. 181
      ml-agents-envs/mlagents/envs/timers.py

9
ml-agents-envs/mlagents/envs/subprocess_env_manager.py


from multiprocessing.connection import Connection
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.env_manager import EnvManager, StepInfo
from mlagents.envs.timers import timed, hierarchical_timer
from mlagents.envs import AllBrainInfo, BrainParameters, ActionInfo

env_worker.previous_all_action_info = all_action_info
env_worker.send("step", all_action_info)
step_brain_infos: List[AllBrainInfo] = [
self.env_workers[i].recv().payload for i in range(len(self.env_workers))
]
with hierarchical_timer("recv"):
step_brain_infos: List[AllBrainInfo] = [
self.env_workers[i].recv().payload for i in range(len(self.env_workers))
]
steps = []
for i in range(len(step_brain_infos)):
env_worker = self.env_workers[i]

for env in self.env_workers:
env.send(name, payload)
@timed
def _take_step(self, last_step: StepInfo) -> Dict[str, ActionInfo]:
all_action_info: Dict[str, ActionInfo] = {}
for brain_name, brain_info in last_step.current_all_brain_info.items():

3
ml-agents/mlagents/trainers/ppo/policy.py


import logging
import numpy as np
from mlagents.envs.timers import timed
from mlagents.trainers import BrainInfo, ActionInfo
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.tf_policy import TFPolicy

"update_batch": self.model.update_batch,
}
@timed
def evaluate(self, brain_info):
"""
Evaluates policy for the agent experiences provided.

run_out["random_normal_epsilon"] = epsilon
return run_out
@timed
def update(self, mini_batch, num_sequences):
"""
Updates model using buffer.

24
ml-agents/mlagents/trainers/trainer_controller.py


"""Launches trainers for each External Brains in a Unity Environment."""
import os
import json
import logging
from typing import *

from mlagents.envs.env_manager import StepInfo
from mlagents.envs.subprocess_env_manager import SubprocessEnvManager
from mlagents.envs.exception import UnityEnvironmentException
from mlagents.envs.timers import hierarchical_timer, get_timer_tree, timed
from mlagents.trainers import Trainer, TrainerMetrics
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.bc.offline_trainer import OfflineBCTrainer

for brain_name in self.trainers.keys():
if brain_name in self.trainer_metrics:
self.trainers[brain_name].write_training_metrics()
def _write_timing_tree(self) -> None:
timing_path = f"{self.summaries_dir}/{self.run_id}_timers.json"
try:
with open(timing_path, "w") as f:
json.dump(get_timer_tree(), f, indent=2)
except FileNotFoundError:
self.logger.warning(
f"Unable to save to {timing_path}. Make sure the directory exists"
)
def _export_graph(self):
"""

if self.train_model:
self._write_training_metrics()
self._export_graph()
self._write_timing_tree()
@timed
def advance(self, env: SubprocessEnvManager) -> int:
if self.meta_curriculum:
# Get the sizes of the reward buffers.

if changed:
self.trainers[brain_name].reward_buffer.clear()
time_start_step = time()
new_step_infos = env.step()
delta_time_step = time() - time_start_step
with hierarchical_timer("env_step"):
time_start_step = time()
new_step_infos = env.step()
delta_time_step = time() - time_start_step
for step_info in new_step_infos:
for brain_name, trainer in self.trainers.items():

trainer.increment_step(len(new_step_infos))
if trainer.is_ready_update():
# Perform gradient descent with experience buffer
trainer.update_policy()
with hierarchical_timer("update_policy"):
trainer.update_policy()
env.set_policy(brain_name, trainer.policy)
return len(new_step_infos)

96
ml-agents-envs/mlagents/envs/tests/test_timers.py


from unittest import mock
from mlagents.envs import timers
@timers.timed
def decorated_func(x: int = 0, y: float = 1.0) -> str:
return f"{x} + {y} = {x + y}"
def test_timers() -> None:
with mock.patch(
"mlagents.envs.timers._global_timer_stack", new_callable=timers.TimerStack
) as test_timer:
# First, run some simple code
with timers.hierarchical_timer("top_level"):
for i in range(3):
with timers.hierarchical_timer("multiple"):
decorated_func()
raised = False
try:
with timers.hierarchical_timer("raises"):
raise RuntimeError("timeout!")
except RuntimeError:
raised = True
with timers.hierarchical_timer("post_raise"):
assert raised
pass
# We expect the hierarchy to look like
# (root)
# top_level
# multiple
# decorated_func
# raises
# post_raise
root = test_timer.root
assert root.children.keys() == {"top_level"}
top_level = root.children["top_level"]
assert top_level.children.keys() == {"multiple", "raises", "post_raise"}
# make sure the scope was closed properly when the exception was raised
raises = top_level.children["raises"]
assert raises.count == 1
multiple = top_level.children["multiple"]
assert multiple.count == 3
timer_tree = test_timer.get_timing_tree()
expected_tree = {
"name": "root",
"total": mock.ANY,
"count": 1,
"self": mock.ANY,
"children": [
{
"name": "top_level",
"total": mock.ANY,
"count": 1,
"self": mock.ANY,
"children": [
{
"name": "multiple",
"total": mock.ANY,
"count": 3,
"self": mock.ANY,
"children": [
{
"name": "decorated_func",
"total": mock.ANY,
"count": 3,
"self": mock.ANY,
}
],
},
{
"name": "raises",
"total": mock.ANY,
"count": 1,
"self": mock.ANY,
},
{
"name": "post_raise",
"total": mock.ANY,
"count": 1,
"self": mock.ANY,
},
],
}
],
}
assert timer_tree == expected_tree

181
ml-agents-envs/mlagents/envs/timers.py


# # Unity ML-Agents Toolkit
from time import perf_counter
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, TypeVar
"""
Lightweight, hierarchical timers for profiling sections of code.
Example:
@timed
def foo(t):
time.sleep(t)
def main():
for i in range(3):
foo(i + 1)
with hierarchical_timer("context"):
foo(1)
print(get_timer_tree())
This would produce a timer tree like
(root)
"foo"
"context"
"foo"
The total time and counts are tracked for each block of code; in this example "foo" and "context.foo" are considered
distinct blocks, and are tracked separately.
The decorator and contextmanager are equivalent; the context manager may be more useful if you want more control
over the timer name, or are splitting up multiple sections of a large function.
"""
class TimerNode:
"""
Represents the time spent in a block of code.
"""
__slots__ = ["children", "total", "count"]
def __init__(self):
# Note that since dictionary keys are the node names, we don't explicitly store the name on the TimerNode.
self.children: Dict[str, TimerNode] = {}
self.total: float = 0.0
self.count: int = 0
def get_child(self, name: str) -> "TimerNode":
"""
Get the child node corresponding to the name (and create if it doesn't already exist).
"""
child = self.children.get(name)
if child is None:
child = TimerNode()
self.children[name] = child
return child
def add_time(self, elapsed: float) -> None:
"""
Accumulate the time spent in the node (and increment the count).
"""
self.total += elapsed
self.count += 1
class TimerStack:
"""
Tracks all the time spent. Users shouldn't use this directly, they should use the contextmanager below to make
sure that pushes and pops are already matched.
"""
__slots__ = ["root", "stack", "start_time"]
def __init__(self):
self.root = TimerNode()
self.stack = [self.root]
self.start_time = perf_counter()
def push(self, name: str) -> TimerNode:
"""
Called when entering a new block of code that is timed (e.g. with a contextmanager).
"""
current_node: TimerNode = self.stack[-1]
next_node = current_node.get_child(name)
self.stack.append(next_node)
return next_node
def pop(self) -> None:
"""
Called when exiting a new block of code that is timed (e.g. with a contextmanager).
"""
self.stack.pop()
def get_timing_tree(self, node: TimerNode = None) -> Dict[str, Any]:
"""
Recursively build a tree of timings, suitable for output/archiving.
"""
if node is None:
# Special case the root - total is time since it was created, and count is 1
node = self.root
total_elapsed = perf_counter() - self.start_time
res = {"name": "root", "total": total_elapsed, "count": 1}
else:
res = {"total": node.total, "count": node.count}
child_total = 0.0
child_list = []
for child_name, child_node in node.children.items():
child_res: Dict[str, Any] = {
"name": child_name,
**self.get_timing_tree(child_node),
}
child_list.append(child_res)
child_total += child_res["total"]
# "self" time is total time minus all time spent on children
res["self"] = max(0.0, node.total - child_total)
if child_list:
res["children"] = child_list
return res
# Global instance of a TimerStack. This is generally all that we need for profiling, but you can potentially
# create multiple instances and pass them to the contextmanager
_global_timer_stack = TimerStack()
@contextmanager
def hierarchical_timer(name: str, timer_stack: TimerStack = None) -> Generator:
"""
Creates a scoped timer around a block of code. This time spent will automatically be incremented when
the context manager exits.
"""
timer_stack = timer_stack or _global_timer_stack
timer_node = timer_stack.push(name)
start_time = perf_counter()
try:
# The wrapped code block will run here.
yield
finally:
# This will trigger either when the context manager exits, or an exception is raised.
# We'll accumulate the time, and the exception (if any) gets raised automatically.
elapsed = perf_counter() - start_time
timer_node.add_time(elapsed)
timer_stack.pop()
# This is used to ensure the signature of the decorated function is preserved
# See also https://github.com/python/mypy/issues/3157
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
def timed(func: FuncT) -> FuncT:
"""
Decorator for timing a function or method. The name of the timer will be the qualified name of the function.
Usage:
@timed
def my_func(x, y):
return x + y
Note that because this doesn't take arguments, the global timer stack is always used.
"""
def wrapped(*args, **kwargs):
with hierarchical_timer(func.__qualname__):
return func(*args, **kwargs)
return wrapped # type: ignore
def get_timer_tree(timer_stack: TimerStack = None) -> Dict[str, Any]:
"""
Return the tree of timings from the TimerStack as a dictionary (or the global stack if none is provided)
"""
timer_stack = timer_stack or _global_timer_stack
return timer_stack.get_timing_tree()
正在加载...
取消
保存