浏览代码

Add "gauges" to timer system (#2329)

* WIP still needs tests and merging from multiprocess

* cleanup gauges

* add TODO for subprocesses
/develop-generalizationTraining-TrainerController
GitHub 5 年前
当前提交
83875376
共有 4 个文件被更改，包括 70 次插入和 3 次删除
  1. 1
      ml-agents-envs/mlagents/envs/subprocess_env_manager.py
  2. 7
      ml-agents-envs/mlagents/envs/tests/test_timers.py
  3. 63
      ml-agents-envs/mlagents/envs/timers.py
  4. 2
      ml-agents/mlagents/trainers/trainer.py

1
ml-agents-envs/mlagents/envs/subprocess_env_manager.py


# So after we send back the root timer, we can safely clear them.
# Note that we could randomly return timers a fraction of the time if we wanted to reduce
# the data transferred.
# TODO get gauges from the workers and merge them in the main process too.
step_response = StepResponse(all_brain_info, get_timer_root())
step_queue.put(EnvironmentResponse("step", worker_id, step_response))
reset_timers()

7
ml-agents-envs/mlagents/envs/tests/test_timers.py


@timers.timed
def decorated_func(x: int = 0, y: float = 1.0) -> str:
    """Record the sum of ``x`` and ``y`` in the "my_gauge" gauge and describe it.

    Returns a string of the form ``"<x> + <y> = <x + y>"``.
    """
    total = x + y
    timers.set_gauge("my_gauge", total)
    return f"{x} + {y} = {total}"

with timers.hierarchical_timer("top_level"):
for i in range(3):
with timers.hierarchical_timer("multiple"):
decorated_func()
decorated_func(i, i)
raised = False
try:

],
}
],
"gauges": [
{"name": "my_gauge", "value": 4.0, "max": 4.0, "min": 0.0, "count": 3}
],
assert timer_tree == expected_tree

63
ml-agents-envs/mlagents/envs/timers.py


# # Unity ML-Agents Toolkit
import math
from typing import Any, Callable, Dict, Generator, TypeVar
from typing import Any, Callable, Dict, Generator, List, TypeVar
"""
Lightweight, hierarchical timers for profiling sections of code.

child.merge(other_child_node, is_parallel=is_parallel)
class GaugeNode:
    """
    Holds the latest reading of a metric, along with the smallest and largest
    readings seen and how many readings were recorded. This is analogous to
    gauges in statsd.
    """

    __slots__ = ["value", "min_value", "max_value", "count"]

    def __init__(self, value: float):
        # The first reading is simultaneously the current, minimum and maximum value.
        self.count = 1
        self.value = self.min_value = self.max_value = value

    def update(self, new_value: float):
        """Record ``new_value`` as the current reading and refresh min/max/count."""
        if new_value < self.min_value:
            self.min_value = new_value
        if new_value > self.max_value:
            self.max_value = new_value
        self.count += 1
        self.value = new_value

    def as_dict(self) -> Dict[str, float]:
        """Return the gauge state keyed by "value", "min", "max" and "count"."""
        snapshot: Dict[str, float] = {}
        snapshot["value"] = self.value
        snapshot["min"] = self.min_value
        snapshot["max"] = self.max_value
        snapshot["count"] = self.count
        return snapshot
class TimerStack:
"""
Tracks all the time spent. Users shouldn't use this directly, they should use the contextmanager below to make

__slots__ = ["root", "stack", "start_time"]
__slots__ = ["root", "stack", "start_time", "gauges"]
self.gauges: Dict[str, GaugeNode] = {}
self.gauges: Dict[str, GaugeNode] = {}
def push(self, name: str) -> TimerNode:
"""

node = self.get_root()
res["name"] = "root"
# Only output gauges at top level
if self.gauges:
res["gauges"] = self._get_gauges()
res["total"] = node.total
res["count"] = node.count

return res
def set_gauge(self, name: str, value: float) -> None:
    """
    Record ``value`` for the gauge called ``name``.

    NaN values are ignored. If no gauge is stored under ``name`` yet, a new
    GaugeNode is created; otherwise the stored one is updated.
    """
    if math.isnan(value):
        return
    existing = self.gauges.get(name)
    if not existing:
        self.gauges[name] = GaugeNode(value)
        return
    existing.update(value)
def _get_gauges(self) -> List[Dict[str, Any]]:
gauges = []
for gauge_name, gauge_node in self.gauges.items():
gauge_dict: Dict[str, Any] = {"name": gauge_name, **gauge_node.as_dict()}
gauges.append(gauge_dict)
return gauges
# Global instance of a TimerStack. This is generally all that we need for profiling, but you can potentially
# create multiple instances and pass them to the contextmanager

return func(*args, **kwargs)
return wrapped # type: ignore
def set_gauge(name: str, value: float, timer_stack: TimerStack = None) -> None:
    """
    Set the gauge ``name`` to ``value``, creating the gauge if it hasn't been set before.

    When no ``timer_stack`` is supplied, the global timer stack is used.
    """
    target_stack = timer_stack if timer_stack else _global_timer_stack
    target_stack.set_gauge(name, value)
def get_timer_tree(timer_stack: TimerStack = None) -> Dict[str, Any]:

2
ml-agents/mlagents/trainers/trainer.py


from collections import deque
from mlagents.envs import UnityException, AllBrainInfo, ActionInfoOutputs
from mlagents.envs.timers import set_gauge
from mlagents.trainers import TrainerMetrics
LOGGER = logging.getLogger("mlagents.trainers")

is_training,
)
)
set_gauge(f"{self.brain_name}.mean_reward", mean_reward)
else:
LOGGER.info(
" {}: {}: Step: {}. No episode was completed since last summary. {}".format(

正在加载...
取消
保存