Current commit: 2ac242f7 (5 years ago)
10 changed files with 145 additions and 229 deletions:
- ml-agents/mlagents/trainers/learn.py (10 changes)
- ml-agents/mlagents/trainers/ppo/trainer.py (5 changes)
- ml-agents/mlagents/trainers/sac/trainer.py (5 changes)
- ml-agents/mlagents/trainers/stats.py (94 changes)
- ml-agents/mlagents/trainers/tests/test_stats.py (59 changes)
- ml-agents/mlagents/trainers/tests/test_trainer_util.py (3 changes)
- ml-agents/mlagents/trainers/trainer.py (11 changes)
- ml-agents/mlagents/trainers/trainer_controller.py (19 changes)
- ml-agents/mlagents/trainers/tests/test_trainer_metrics.py (46 changes)
- ml-agents/mlagents/trainers/trainer_metrics.py (122 changes)
ml-agents/mlagents/trainers/tests/test_trainer_metrics.py

import unittest.mock as mock

from mlagents.trainers.trainer_metrics import TrainerMetrics


class TestTrainerMetrics:
    def test_field_names(self):
        field_names = [
            "Brain name",
            "Time to update policy",
            "Time since start of training",
            "Time for last experience collection",
            "Number of experiences used for training",
            "Mean return",
        ]
        from mlagents.trainers.trainer_metrics import FIELD_NAMES

        assert FIELD_NAMES == field_names

    @mock.patch(
        "mlagents.trainers.trainer_metrics.time", mock.MagicMock(return_value=42)
    )
    def test_experience_collection_timer(self):
        mock_path = "fake"
        mock_brain_name = "fake"
        trainer_metrics = TrainerMetrics(path=mock_path, brain_name=mock_brain_name)
        trainer_metrics.start_experience_collection_timer()
        trainer_metrics.end_experience_collection_timer()
        assert trainer_metrics.delta_last_experience_collection == 0

    @mock.patch(
        "mlagents.trainers.trainer_metrics.time", mock.MagicMock(return_value=42)
    )
    def test_policy_update_timer(self):
        mock_path = "fake"
        mock_brain_name = "fake"
        fake_buffer_length = 350
        fake_mean_return = 0.3
        trainer_metrics = TrainerMetrics(path=mock_path, brain_name=mock_brain_name)
        trainer_metrics.start_experience_collection_timer()
        trainer_metrics.end_experience_collection_timer()
        trainer_metrics.start_policy_update_timer(
            number_experiences=fake_buffer_length, mean_return=fake_mean_return
        )
        trainer_metrics.end_policy_update()
        fake_row = [mock_brain_name, 0, 0, 0, 350, "0.300"]
        assert trainer_metrics.rows[0] == fake_row
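A note on the zero-valued assertions above: `mock.patch` replaces the `time` function imported into `trainer_metrics` with a constant clock, so every interval computed as end minus start collapses to 42 - 42 == 0. A minimal standalone sketch of the same mechanism (not part of the commit), using only the standard library:

import unittest.mock as mock

# A constant clock: every call returns the same instant.
fake_time = mock.MagicMock(return_value=42)

start = fake_time()          # 42
delta = fake_time() - start  # 42 - 42 == 0
assert delta == 0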
ml-agents/mlagents/trainers/trainer_metrics.py

# # Unity ML-Agents Toolkit
import logging
import csv
from time import time
from typing import List, Optional

LOGGER = logging.getLogger("mlagents.trainers")
FIELD_NAMES = [
    "Brain name",
    "Time to update policy",
    "Time since start of training",
    "Time for last experience collection",
    "Number of experiences used for training",
    "Mean return",
]


class TrainerMetrics:
    """
    Helper class to track and write training metrics. Tracks time since
    this object was initialized.
    """

    def __init__(self, path: str, brain_name: str):
        """
        :param path: Fully qualified path where the CSV is stored.
        :param brain_name: Identifier for the Brain which we are training.
        """
        self.path = path
        self.brain_name = brain_name
        self.rows: List[List[Optional[str]]] = []
        self.time_start_experience_collection: Optional[float] = None
        self.time_training_start = time()
        self.last_buffer_length: Optional[int] = None
        self.last_mean_return: Optional[float] = None
        self.time_policy_update_start: Optional[float] = None
        self.delta_last_experience_collection: Optional[float] = None
        self.delta_policy_update: Optional[float] = None

    def start_experience_collection_timer(self) -> None:
        """
        Inform Metrics class that experience collection is starting. Intended to be idempotent.
        """
        if self.time_start_experience_collection is None:
            self.time_start_experience_collection = time()

    def end_experience_collection_timer(self) -> None:
        """
        Inform Metrics class that experience collection is done.
        """
        if self.time_start_experience_collection:
            curr_delta = time() - self.time_start_experience_collection
            if self.delta_last_experience_collection is None:
                self.delta_last_experience_collection = curr_delta
            else:
                self.delta_last_experience_collection += curr_delta
        self.time_start_experience_collection = None

    def add_delta_step(self, delta: float) -> None:
        """
        Inform Metrics class about time to step in environment.
        """
        if self.delta_last_experience_collection:
            self.delta_last_experience_collection += delta
        else:
            self.delta_last_experience_collection = delta

    def start_policy_update_timer(
        self, number_experiences: int, mean_return: float
    ) -> None:
        """
        Inform Metrics class that policy update has started.
        :param number_experiences: Number of experiences in Buffer at this point.
        :param mean_return: Return averaged across all cumulative returns since last policy update.
        """
        self.last_buffer_length = number_experiences
        self.last_mean_return = mean_return
        self.time_policy_update_start = time()

    def _add_row(self, delta_train_start: float) -> None:
        row: List[Optional[str]] = [self.brain_name]
        row.extend(
            format(c, ".3f") if isinstance(c, float) else c
            for c in [
                self.delta_policy_update,
                delta_train_start,
                self.delta_last_experience_collection,
                self.last_buffer_length,
                self.last_mean_return,
            ]
        )
        self.delta_last_experience_collection = None
        self.rows.append(row)

    def end_policy_update(self) -> None:
        """
        Inform Metrics class that policy update has ended.
        """
        if self.time_policy_update_start:
            self.delta_policy_update = time() - self.time_policy_update_start
        else:
            self.delta_policy_update = 0
        delta_train_start = time() - self.time_training_start
        LOGGER.debug(
            f" Policy Update Training Metrics for {self.brain_name}: "
            f"\n\t\tTime to update Policy: {self.delta_policy_update:0.3f} s \n"
            f"\t\tTime elapsed since training: {delta_train_start:0.3f} s \n"
            f"\t\tTime for experience collection: {(self.delta_last_experience_collection or 0):0.3f} s \n"
            f"\t\tBuffer Length: {self.last_buffer_length or 0} \n"
            f"\t\tReturns : {(self.last_mean_return or 0):0.3f}\n"
        )
        self._add_row(delta_train_start)

    def write_training_metrics(self) -> None:
        """
        Write Training Metrics to CSV.
        """
        with open(self.path, "w") as file:
            writer = csv.writer(file)
            writer.writerow(FIELD_NAMES)
            for row in self.rows:
                writer.writerow(row)
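For orientation, a hedged usage sketch of the `TrainerMetrics` class shown above (not part of the commit); the output path, brain name, and loop are hypothetical, but the call order follows the methods as defined:

from mlagents.trainers.trainer_metrics import TrainerMetrics

# Hypothetical output path and brain name, for illustration only.
metrics = TrainerMetrics(path="./Sample_Timers.csv", brain_name="Sample")

for _ in range(3):
    # Accumulates wall-clock time spent gathering experiences;
    # repeated intervals add up until the next row is written.
    metrics.start_experience_collection_timer()
    # ... step the environment and collect trajectories here ...
    metrics.end_experience_collection_timer()

# Marks the start of a policy update over the collected buffer.
metrics.start_policy_update_timer(number_experiences=350, mean_return=0.3)
# ... run the actual policy update here ...
metrics.end_policy_update()  # logs the timings and buffers one CSV row

# Writes FIELD_NAMES as the header, then one row per policy update.
metrics.write_training_metrics()

Each row's columns line up with FIELD_NAMES: brain name, time to update policy, time since start of training, time for last experience collection, number of experiences, and mean return.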