浏览代码

Make progress bar a statswriter

/develop/progress-bar
Ervin Teng 5 年前
当前提交
49df4038
共有 3 个文件被更改,包括 56 次插入7 次删除
  1. 3
      ml-agents/mlagents/trainers/learn.py
  2. 52
      ml-agents/mlagents/trainers/stats.py
  3. 8
      ml-agents/mlagents/trainers/trainer/trainer.py

3
ml-agents/mlagents/trainers/learn.py


StatsReporter,
GaugeWriter,
ConsoleWriter,
ProgressBarWriter,
)
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.sampler_class import SamplerManager

tb_writer = TensorboardWriter(summaries_dir)
gauge_write = GaugeWriter()
console_writer = ConsoleWriter()
progress_bar_writer = ProgressBarWriter()
StatsReporter.add_writer(progress_bar_writer)
if options.env_path is None:
port = UnityEnvironment.DEFAULT_EDITOR_PORT

52
ml-agents/mlagents/trainers/stats.py


from collections import defaultdict
from typing import List, Dict, NamedTuple
from typing import List, Dict, NamedTuple, Any
import numpy as np
import abc
import csv

from mlagents.tf_utils import tf
from mlagents_envs.timers import set_gauge
from mlagents.trainers.progress_bar import ProgressBar
logger = logging.getLogger("mlagents.trainers")

and writes it out by some method.
"""
def __init__(self):
self.properties: Dict[str, Dict[str, Any]] = {}
def set_property(self, category: str, key: str, val: Any) -> None:
"""
Sets the max steps for a particular category. Used for tracking training progress. Optional to implement.
"""
self.properties[category] = {key: val}
@abc.abstractmethod
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int

class ConsoleWriter(StatsWriter):
def __init__(self):
super().__init__()
self.training_start_time = time.time()
def write_stats(

pass
class ProgressBarWriter(StatsWriter):
def __init__(self):
"""
A StatsWriter that draws a progress bar at the bottom of the console.
"""
super().__init__()
self.progress_bars: Dict[str, ProgressBar] = {}
def set_property(self, category: str, key: str, val: Any) -> None:
"""
Check to see if a max_steps was added. If so, create the progress bar.
"""
super().set_property(category, key, val)
if key == "max_steps":
self.progress_bars[category] = ProgressBar(category, "steps", 0, val)
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int
) -> None:
if category in self.progress_bars:
self.progress_bars[category].update(step)
def write_text(self, category: str, text: str, step: int) -> None:
pass
class TensorboardWriter(StatsWriter):
def __init__(self, base_dir: str):
"""

"""
super().__init__()
self.summary_writers: Dict[str, tf.summary.FileWriter] = {}
self.base_dir: str = base_dir

:param required_fields: If provided, the CSV writer won't write until these fields have statistics to write for
them.
"""
super().__init__()
# We need to keep track of the fields in the CSV, as all rows need the same fields.
self.csv_fields: Dict[str, List[str]] = {}
self.required_fields = required_fields if required_fields else []

writers: List[StatsWriter] = []
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list))
def __init__(self, category):
def __init__(self, category: str):
"""
Generic StatsReporter. A category is the broadest type of storage (would
correspond the run name and trainer name, e.g. 3DBalltest_3DBall. A key is the

self.category: str = category
def set_property(self, key: str, value: Any) -> None:
"""
Add a generic property to all available writers.
:param key: The type of property.
:param value: The value of the property.
"""
for writer in StatsReporter.writers:
writer.set_property(self.category, key, value)
@staticmethod
def add_writer(writer: StatsWriter) -> None:

8
ml-agents/mlagents/trainers/trainer/trainer.py


from mlagents.trainers.policy import Policy
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.timers import hierarchical_timer
from mlagents.trainers.progress_bar import ProgressBar
logger = logging.getLogger("mlagents.trainers")

self.trainer_parameters = trainer_parameters
self.summary_path = trainer_parameters["summary_path"]
self.stats_reporter = StatsReporter(self.summary_path)
self.stats_reporter.set_property(
"max_steps", int(float(self.trainer_parameters["max_steps"]))
)
self.cumulative_returns_since_policy_update: List[float] = []
self.is_training = training
self._reward_buffer: Deque[float] = deque(maxlen=reward_buff_cap)

self.training_start_time = time.time()
self.summary_freq = self.trainer_parameters["summary_freq"]
self.next_summary_step = self.summary_freq
self.progress_bar = ProgressBar(
self.brain_name, "Steps", count=self.get_step, total=self.get_max_steps
)
def _check_param_keys(self):
for k in self.param_keys:

Saves training statistics to Tensorboard.
"""
self.stats_reporter.add_stat("Is Training", float(self.should_still_train))
self.progress_bar.update(step)
self.stats_reporter.write_stats(int(step))
@abc.abstractmethod

正在加载...
取消
保存