浏览代码

Fix timers when using multithreading. (#3901)

/release_1_branch
GitHub 5 年前
当前提交
f501c395
共有 5 个文件被更改,包括 101 次插入20 次删除
  1. 3
      .pylintrc
  2. 5
      docs/Profiling-Python.md
  3. 5
      ml-agents-envs/mlagents_envs/tests/test_timers.py
  4. 73
      ml-agents-envs/mlagents_envs/timers.py
  5. 35
      ml-agents/mlagents/trainers/trainer_controller.py

3
.pylintrc


# Using the global statement
W0603,
# "Access to a protected member _foo of a client class (protected-access)"
W0212

5
docs/Profiling-Python.md


is optional and defaults to false.
### Parallel execution
#### Subprocesses
For code that executes in multiple processes (for example, SubprocessEnvManager), we periodically send the timer
information back to the "main" process, aggregate the timers there, and flush them in the subprocess. Note that
(depending on the number of processes) this can result in timers where the total time may exceed the parent's total

#### Threads
Timers currently use `time.perf_counter()` to track time spent. Because this measures elapsed wall-clock time rather
than per-thread CPU time, it may not give accurate results when multiple threads run concurrently. If this is
problematic, set `threaded: false` in your trainer configuration.

5
ml-agents-envs/mlagents_envs/tests/test_timers.py


def test_timers() -> None:
with mock.patch(
"mlagents_envs.timers._global_timer_stack", new_callable=timers.TimerStack
) as test_timer:
test_timer = timers.TimerStack()
with mock.patch("mlagents_envs.timers._get_thread_timer", return_value=test_timer):
# First, run some simple code
with timers.hierarchical_timer("top_level"):
for i in range(3):

73
ml-agents-envs/mlagents_envs/timers.py


import math
import sys
import time
import threading
from typing import Any, Callable, Dict, Generator, TypeVar
from typing import Any, Callable, Dict, Generator, Optional, TypeVar
TIMER_FORMAT_VERSION = "0.1.0"

Tracks the most recent value of a metric. This is analogous to gauges in statsd.
"""
__slots__ = ["value", "min_value", "max_value", "count"]
__slots__ = ["value", "min_value", "max_value", "count", "_timestamp"]
def __init__(self, value: float):
self.value = value

# Internal timestamp so we can determine priority.
self._timestamp = time.time()
def update(self, new_value: float) -> None:
self.min_value = min(self.min_value, new_value)

self._timestamp = time.time()
def merge(self, other: "GaugeNode") -> None:
    """
    Fold another gauge's statistics into this one.

    The current value (and its timestamp) is taken from whichever gauge was
    updated most recently; min/max are combined and the counts are summed.
    :param other: The gauge whose data is merged into this gauge.
    :return:
    """
    other_is_newer = other._timestamp > self._timestamp
    if other_is_newer:
        # The more recently written reading wins.
        self.value, self._timestamp = other.value, other._timestamp

    self.count = self.count + other.count
    self.max_value = max(self.max_value, other.max_value)
    self.min_value = min(self.min_value, other.min_value)
def as_dict(self) -> Dict[str, float]:
return {

self.metadata["command_line_arguments"] = " ".join(sys.argv)
# Global instance of a TimerStack. This is generally all that we need for profiling, but you can potentially
# create multiple instances and pass them to the contextmanager
_global_timer_stack = TimerStack()
# Maintain a separate "global" timer per thread, so that they don't accidentally conflict with each other.
_thread_timer_stacks: Dict[int, TimerStack] = {}
def _get_thread_timer() -> TimerStack:
    """
    Return the TimerStack registered for the calling thread, creating and
    registering a new one the first time the thread asks for it.
    """
    thread_id = threading.get_ident()
    if thread_id in _thread_timer_stacks:
        return _thread_timer_stacks[thread_id]

    # First request from this thread: register a fresh stack under its ident.
    new_stack = TimerStack()
    _thread_timer_stacks[thread_id] = new_stack
    return new_stack
def get_timer_stack_for_thread(t: threading.Thread) -> Optional[TimerStack]:
    """
    Look up the TimerStack recorded for the given thread.
    :param t: The thread whose TimerStack is wanted.
    :return: The thread's TimerStack, or None if the thread has no ident or
        no TimerStack is recorded under its ident.
    """
    thread_id = t.ident
    # A missing ident means the thread hasn't started, which shouldn't ever happen.
    return None if thread_id is None else _thread_timer_stacks.get(thread_id)
@contextmanager

the context manager exits.
"""
timer_stack = timer_stack or _global_timer_stack
timer_stack = timer_stack or _get_thread_timer()
timer_node = timer_stack.push(name)
start_time = time.perf_counter()

"""
Updates the value of the gauge (or creates it if it hasn't been set before).
"""
timer_stack = timer_stack or _global_timer_stack
timer_stack = timer_stack or _get_thread_timer()
def merge_gauges(gauges: Dict[str, GaugeNode], timer_stack: TimerStack = None) -> None:
    """
    Merge the given gauges into a TimerStack (or into the current thread's
    stack if none is provided).
    :param gauges: The gauges to merge in, keyed by gauge name.
    :param timer_stack: The stack that receives the gauges.
    :return:
    """
    destination = timer_stack or _get_thread_timer()
    for gauge_name, incoming in gauges.items():
        if gauge_name not in destination.gauges:
            # Nothing to combine with: the incoming node itself is stored.
            destination.gauges[gauge_name] = incoming
            continue
        destination.gauges[gauge_name].merge(incoming)
timer_stack = timer_stack or _global_timer_stack
timer_stack = timer_stack or _get_thread_timer()
Return the tree of timings from the TimerStack as a dictionary (or the global stack if none is provided)
Return the tree of timings from the TimerStack as a dictionary (or the
current thread's stack if none is provided)
timer_stack = timer_stack or _global_timer_stack
timer_stack = timer_stack or _get_thread_timer()
Get the root TimerNode of the timer_stack (or the global TimerStack if not specified)
Get the root TimerNode of the timer_stack (or the current thread's
TimerStack if not specified)
timer_stack = timer_stack or _global_timer_stack
timer_stack = timer_stack or _get_thread_timer()
Reset the timer_stack (or the global TimerStack if not specified)
Reset the timer_stack (or the current thread's TimerStack if not specified)
timer_stack = timer_stack or _global_timer_stack
timer_stack = timer_stack or _get_thread_timer()
timer_stack.reset()

35
ml-agents/mlagents/trainers/trainer_controller.py


UnityCommunicatorStoppedException,
)
from mlagents.trainers.sampler_class import SamplerManager
from mlagents_envs.timers import hierarchical_timer, timed
from mlagents_envs.timers import (
hierarchical_timer,
timed,
get_timer_stack_for_thread,
merge_gauges,
)
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.trainers.trainer_util import TrainerFactory

if self._should_save_model(global_step):
self._save_model()
# Stop advancing trainers
self.kill_trainers = True
self.join_threads()
# Final save Tensorflow model
if global_step != 0 and self.train_model:
self._save_model()

UnityEnvironmentException,
UnityCommunicatorStoppedException,
) as ex:
self.kill_trainers = True
self.join_threads()
if self.train_model:
self._save_model_when_interrupted()

trainer.advance()
return num_steps
def join_threads(self, timeout_seconds: float = 1.0) -> None:
    """
    Tell the trainer threads to stop, wait for each of them, and then merge
    their timer trees and gauges into the calling thread's timers.
    :param timeout_seconds: Timeout passed to each thread's join().
    :return:
    """
    self.kill_trainers = True
    for worker in self.trainer_threads:
        try:
            worker.join(timeout_seconds)
        except Exception:
            # Best effort: a failed join must not prevent joining the rest.
            pass

    with hierarchical_timer("trainer_threads") as main_timer_node:
        for worker in self.trainer_threads:
            worker_stack = get_timer_stack_for_thread(worker)
            if not worker_stack:
                # No timers were recorded for this thread.
                continue
            main_timer_node.merge(
                worker_stack.root, root_name="thread_root", is_parallel=True
            )
            merge_gauges(worker_stack.gauges)
def trainer_update_func(self, trainer: Trainer) -> None:
while not self.kill_trainers:
正在加载...
取消
保存