Merge pull request #1858 from Unity-Technologies/develop-esh-metrics
Added logging per Brain of time to update policy, time elapsed during training, time to collect experiences, buffer length, and average return per policy (develop-generalizationTraining-TrainerController)
GitHub · 6 years ago · current commit a0b44f1b
15 changed files with 277 additions and 61 deletions
- docs/Training-ML-Agents.md (4)
- ml-agents/mlagents/trainers/__init__.py (1)
- ml-agents/mlagents/trainers/bc/offline_trainer.py (3)
- ml-agents/mlagents/trainers/bc/online_trainer.py (3)
- ml-agents/mlagents/trainers/bc/trainer.py (14)
- ml-agents/mlagents/trainers/learn.py (23)
- ml-agents/mlagents/trainers/ppo/trainer.py (30)
- ml-agents/mlagents/trainers/tests/test_learn.py (1)
- ml-agents/mlagents/trainers/tests/test_trainer_controller.py (2)
- ml-agents/mlagents/trainers/trainer.py (81)
- ml-agents/mlagents/trainers/trainer_controller.py (30)
- ml-agents/mlagents/trainers/trainer_metrics.py (107)
- ml-agents/tests/trainers/test_trainer_metrics.py (39)
ml-agents/mlagents/trainers/trainer_metrics.py:

# # Unity ML-Agents Toolkit
import logging
import csv
from time import time

LOGGER = logging.getLogger("mlagents.trainers")

FIELD_NAMES = ['Brain name', 'Time to update policy',
               'Time since start of training', 'Time for last experience collection',
               'Number of experiences used for training', 'Mean return']


class TrainerMetrics:
    """
    Helper class to track and write training metrics. Tracks time elapsed since
    an instance of this class is initialized.
    """
    def __init__(self, path: str, brain_name: str):
        """
        :str path: Fully qualified path where the CSV is stored.
        :str brain_name: Identifier for the Brain which we are training.
        """
        self.path = path
        self.brain_name = brain_name
        self.rows = []
        self.time_start_experience_collection = None
        self.time_training_start = time()
        self.last_buffer_length = None
        self.last_mean_return = None
        self.time_policy_update_start = None
        self.delta_last_experience_collection = None
        self.delta_policy_update = None

    def start_experience_collection_timer(self):
        """
        Inform Metrics class that experience collection is starting. Intended to be idempotent.
        """
        if self.time_start_experience_collection is None:
            self.time_start_experience_collection = time()

    def end_experience_collection_timer(self):
        """
        Inform Metrics class that experience collection is done.
        """
        if self.time_start_experience_collection:
            curr_delta = time() - self.time_start_experience_collection
            if self.delta_last_experience_collection is None:
                self.delta_last_experience_collection = curr_delta
            else:
                self.delta_last_experience_collection += curr_delta
        self.time_start_experience_collection = None

    def add_delta_step(self, delta: float):
        """
        Inform Metrics class about the time taken to step the environment.
        """
        if self.delta_last_experience_collection:
            self.delta_last_experience_collection += delta
        else:
            self.delta_last_experience_collection = delta

    def start_policy_update_timer(self, number_experiences: int, mean_return: float):
        """
        Inform Metrics class that a policy update has started.
        :int number_experiences: Number of experiences in the Buffer at this point.
        :float mean_return: Return averaged across all cumulative returns since the last policy update.
        """
        self.last_buffer_length = number_experiences
        self.last_mean_return = mean_return
        self.time_policy_update_start = time()

    def _add_row(self, delta_train_start):
        row = [self.brain_name]
        row.extend(format(c, '.3f') if isinstance(c, float) else c
                   for c in [self.delta_policy_update, delta_train_start,
                             self.delta_last_experience_collection,
                             self.last_buffer_length, self.last_mean_return])
        self.delta_last_experience_collection = None
        self.rows.append(row)

    def end_policy_update(self):
        """
        Inform Metrics class that the policy update has ended.
        """
        if self.time_policy_update_start:
            self.delta_policy_update = time() - self.time_policy_update_start
        else:
            self.delta_policy_update = 0
        delta_train_start = time() - self.time_training_start
        LOGGER.debug(" Policy Update Training Metrics for {}: "
                     "\n\t\tTime to update Policy: {:0.3f} s \n"
                     "\t\tTime elapsed since training: {:0.3f} s \n"
                     "\t\tTime for experience collection: {:0.3f} s \n"
                     "\t\tBuffer Length: {} \n"
                     "\t\tReturns : {:0.3f}\n"
                     .format(self.brain_name, self.delta_policy_update,
                             delta_train_start, self.delta_last_experience_collection,
                             self.last_buffer_length, self.last_mean_return))
        self._add_row(delta_train_start)

    def write_training_metrics(self):
        """
        Write the accumulated training metrics to CSV.
        """
        with open(self.path, 'w') as file:
            writer = csv.writer(file)
            writer.writerow(FIELD_NAMES)
            for row in self.rows:
                writer.writerow(row)
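For context, a minimal sketch of how a trainer might drive this class; the file name, Brain name, and numbers below are made up for illustration, and the actual wiring into training lives in trainer.py, ppo/trainer.py, and trainer_controller.py in this change:

from time import sleep
from mlagents.trainers.trainer_metrics import TrainerMetrics

# Hypothetical path and Brain name, for illustration only.
metrics = TrainerMetrics(path='example_timers.csv', brain_name='ExampleBrain')

for _ in range(3):
    metrics.start_experience_collection_timer()
    sleep(0.01)                              # stand-in for stepping the environment
    metrics.end_experience_collection_timer()

metrics.start_policy_update_timer(number_experiences=1024, mean_return=1.5)
sleep(0.01)                                  # stand-in for the actual policy update
metrics.end_policy_update()                  # logs the deltas and buffers one CSV row

metrics.write_training_metrics()             # writes FIELD_NAMES plus the buffered rows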
ml-agents/tests/trainers/test_trainer_metrics.py:

import unittest.mock as mock
from mlagents.trainers import TrainerMetrics


class TestTrainerMetrics:

    def test_field_names(self):
        field_names = ['Brain name', 'Time to update policy',
                       'Time since start of training',
                       'Time for last experience collection',
                       'Number of experiences used for training', 'Mean return']
        from mlagents.trainers.trainer_metrics import FIELD_NAMES
        assert FIELD_NAMES == field_names

    @mock.patch('mlagents.trainers.trainer_metrics.time', mock.MagicMock(return_value=42))
    def test_experience_collection_timer(self):
        mock_path = 'fake'
        mock_brain_name = 'fake'
        trainer_metrics = TrainerMetrics(path=mock_path,
                                         brain_name=mock_brain_name)
        trainer_metrics.start_experience_collection_timer()
        trainer_metrics.end_experience_collection_timer()
        assert trainer_metrics.delta_last_experience_collection == 0

    @mock.patch('mlagents.trainers.trainer_metrics.time', mock.MagicMock(return_value=42))
    def test_policy_update_timer(self):
        mock_path = 'fake'
        mock_brain_name = 'fake'
        fake_buffer_length = 350
        fake_mean_return = 0.3
        trainer_metrics = TrainerMetrics(path=mock_path,
                                         brain_name=mock_brain_name)
        trainer_metrics.start_experience_collection_timer()
        trainer_metrics.end_experience_collection_timer()
        trainer_metrics.start_policy_update_timer(number_experiences=fake_buffer_length,
                                                  mean_return=fake_mean_return)
        trainer_metrics.end_policy_update()
        # With time() mocked to a constant, every delta is 0; the mean return is formatted to 3 decimals.
        fake_row = [mock_brain_name, 0, 0, 0, 350, '0.300']
        assert trainer_metrics.rows[0] == fake_row
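The tests above exercise the timers with time() mocked out; to inspect real output, the CSV this change produces can be read back with the standard library. A minimal sketch, where the filename is hypothetical and should match whatever path was passed to TrainerMetrics:

import csv

# Hypothetical filename; use the path given to TrainerMetrics(path=...).
with open('example_timers.csv') as f:
    for record in csv.DictReader(f):
        print(record['Brain name'],
              record['Time to update policy'],
              record['Mean return'])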