浏览代码
Added logging per Brain of time to update policy, time elapsed during training, time to collect experiences, buffer length, average return
/develop-generalizationTraining-TrainerController
Added logging per Brain of time to update policy, time elapsed during training, time to collect experiences, buffer length, average return
/develop-generalizationTraining-TrainerController
eshvk
6 年前
当前提交
cc9bdf17
共有 14 个文件被更改,包括 252 次插入 和 55 次删除
-
5docs/Training-ML-Agents.md
-
1ml-agents/mlagents/trainers/__init__.py
-
3ml-agents/mlagents/trainers/bc/offline_trainer.py
-
3ml-agents/mlagents/trainers/bc/online_trainer.py
-
12ml-agents/mlagents/trainers/bc/trainer.py
-
23ml-agents/mlagents/trainers/learn.py
-
23ml-agents/mlagents/trainers/ppo/trainer.py
-
1ml-agents/mlagents/trainers/tests/test_learn.py
-
78ml-agents/mlagents/trainers/trainer.py
-
20ml-agents/mlagents/trainers/trainer_controller.py
-
98ml-agents/mlagents/trainers/trainer_metrics.py
-
40ml-agents/tests/trainers/test_trainer_metrics.py
|
|||
# # Unity ML-Agents Toolkit |
|||
import logging |
|||
import csv |
|||
from time import time |
|||
|
|||
LOGGER = logging.getLogger("mlagents.trainers") |
|||
|
|||
class TrainerMetrics: |
|||
""" |
|||
Helper class to track, write training metrics. Tracks time since object |
|||
of this class is initialized. |
|||
""" |
|||
|
|||
def __init__(self, path: str, brain_name: str): |
|||
""" |
|||
:str path: Fully qualified path where CSV is stored. |
|||
:str brain_name: Identifier for the Brain which we are training |
|||
""" |
|||
self.path = path |
|||
self.brain_name = brain_name |
|||
self.FIELD_NAMES = ['Brain name', 'Time to update policy', |
|||
'Time since start of training', 'Time for last experience collection', |
|||
'Number of experiences used for training', 'Mean return'] |
|||
self.rows = [] |
|||
self.time_start_experience_collection = None |
|||
self.time_training_start = time() |
|||
self.last_buffer_length = None |
|||
self.last_mean_return = None |
|||
self.time_policy_update_start = None |
|||
self.delta_last_experience_collection = None |
|||
self.delta_policy_update = None |
|||
|
|||
def start_experience_collection_timer(self): |
|||
""" |
|||
Inform Metrics class that experience collection is starting. Intended to be idempotent |
|||
""" |
|||
if self.time_start_experience_collection is None: |
|||
self.time_start_experience_collection = time() |
|||
|
|||
def end_experience_collection_timer(self): |
|||
""" |
|||
Inform Metrics class that experience collection is done. |
|||
""" |
|||
if self.start_experience_collection_timer: |
|||
self.delta_last_experience_collection = time() - self.time_start_experience_collection |
|||
else: |
|||
self.delta_last_experience_collection = 0.0 |
|||
self.time_start_experience_collection = None |
|||
|
|||
def start_policy_update_timer(self, number_experiences: int, mean_return: float): |
|||
""" |
|||
Inform Metrics class that policy update has started. |
|||
:int number_experiences: Number of experiences in Buffer at this point. |
|||
:float mean_return: Return averaged across all cumulative returns since last policy update |
|||
""" |
|||
self.last_buffer_length = number_experiences |
|||
self.last_mean_return = mean_return |
|||
self.time_policy_update_start = time() |
|||
|
|||
def _add_row(self, delta_train_start): |
|||
row = [self.brain_name] |
|||
row.extend(format(c, '.3f') if isinstance(c, float) else c |
|||
for c in [self.delta_policy_update, delta_train_start, |
|||
self.delta_last_experience_collection, |
|||
self.last_buffer_length, self.last_mean_return]) |
|||
self.rows.append(row) |
|||
|
|||
|
|||
def end_policy_update(self): |
|||
""" |
|||
Inform Metrics class that policy update has started. |
|||
""" |
|||
if self.time_policy_update_start: |
|||
self.delta_policy_update = time() - self.time_policy_update_start |
|||
else: |
|||
self.delta_policy_update = 0 |
|||
delta_train_start = time() - self.time_training_start |
|||
LOGGER.debug(" Policy Update Training Metrics for {}: " |
|||
"\n\t\tTime to update Policy: {:0.3f} s \n" |
|||
"\t\tTime elapsed since training: {:0.3f} s \n" |
|||
"\t\tTime for experience collection: {:0.3f} s \n" |
|||
"\t\tBuffer Length: {} \n" |
|||
"\t\tReturns : {:0.3f}\n" |
|||
.format(self.brain_name, self.delta_policy_update, |
|||
delta_train_start, self.delta_last_experience_collection, |
|||
self.last_buffer_length, self.last_mean_return)) |
|||
self._add_row(delta_train_start) |
|||
|
|||
|
|||
def write_training_metrics(self): |
|||
""" |
|||
Write Training Metrics to CSV |
|||
""" |
|||
with open(self.path, 'w') as f: |
|||
writer = csv.writer(f) |
|||
writer.writerow(self.FIELD_NAMES) |
|||
for row in self.rows: |
|||
writer.writerow(row) |
|
|||
import unittest.mock as mock |
|||
from mlagents.trainers import TrainerMetrics |
|||
|
|||
class TestTrainerMetrics: |
|||
|
|||
def test_field_names(self): |
|||
field_names = ['Brain name', 'Time to update policy', |
|||
'Time since start of training', 'Time for last experience collection', 'Number of experiences used for training', 'Mean return'] |
|||
mock_path = 'fake' |
|||
mock_brain_name = 'fake' |
|||
trainer_metrics = TrainerMetrics(path=mock_path, |
|||
brain_name=mock_brain_name) |
|||
assert trainer_metrics.FIELD_NAMES == field_names |
|||
|
|||
@mock.patch('mlagents.trainers.trainer_metrics.time', mock.MagicMock(return_value=42)) |
|||
def test_experience_collection_timer(self): |
|||
mock_path = 'fake' |
|||
mock_brain_name = 'fake' |
|||
trainer_metrics = TrainerMetrics(path=mock_path, |
|||
brain_name=mock_brain_name) |
|||
trainer_metrics.start_experience_collection_timer() |
|||
trainer_metrics.end_experience_collection_timer() |
|||
assert trainer_metrics.delta_last_experience_collection == 0 |
|||
|
|||
@mock.patch('mlagents.trainers.trainer_metrics.time', mock.MagicMock(return_value=42)) |
|||
def test_policy_update_timer(self): |
|||
mock_path = 'fake' |
|||
mock_brain_name = 'fake' |
|||
fake_buffer_length = 350 |
|||
fake_mean_return = 0.3 |
|||
trainer_metrics = TrainerMetrics(path=mock_path, |
|||
brain_name=mock_brain_name) |
|||
trainer_metrics.start_experience_collection_timer() |
|||
trainer_metrics.end_experience_collection_timer() |
|||
trainer_metrics.start_policy_update_timer(number_experiences=fake_buffer_length, |
|||
mean_return=fake_mean_return) |
|||
trainer_metrics.end_policy_update() |
|||
fake_row = [mock_brain_name, 0,0, 0, 350, '0.300'] |
|||
assert trainer_metrics.rows[0] == fake_row |
|||
|
撰写
预览
正在加载...
取消
保存
Reference in new issue