浏览代码

Make progress bar class and add to trainer

/develop/progress-bar
Ervin Teng 5 年前
当前提交
ce6ab0de
共有 2 个文件被更改,包括 24 次插入0 次删除
  1. 5
      ml-agents/mlagents/trainers/trainer/trainer.py
  2. 19
      ml-agents/mlagents/trainers/progress_bar.py

5
ml-agents/mlagents/trainers/trainer/trainer.py


from mlagents.trainers.policy import Policy
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.timers import hierarchical_timer
from mlagents.trainers.progress_bar import ProgressBar
logger = logging.getLogger("mlagents.trainers")

self.training_start_time = time.time()
self.summary_freq = self.trainer_parameters["summary_freq"]
self.next_summary_step = self.summary_freq
self.progress_bar = ProgressBar(
self.brain_name, "Steps", count=self.get_step, total=self.get_max_steps
)
def _check_param_keys(self):
for k in self.param_keys:

self.run_id, self.brain_name, step, is_training
)
)
self.progress_bar.update(step)
self.stats_reporter.write_stats(int(step))
@abc.abstractmethod

19
ml-agents/mlagents/trainers/progress_bar.py


import enlighten
class ProgressBar:
manager = enlighten.get_manager()
def __init__(self, name: str, unit: str, count: int, total: int):
self.progress_bar = ProgressBar.manager.counter(
desc=name, unit=unit, count=count, total=total
)
self.progress_bar.refresh()
def update(self, count: int) -> None:
"""
Updates the progress bar based on the current count.
:param count: Current count.
"""
self.progress_bar.count = count
self.progress_bar.refresh()
正在加载...
取消
保存