浏览代码

[change] Move hyperparameter printing entirely into StatsWriters (#3630)

/bug-failed-api-check
GitHub 5 年前
当前提交
25cc9f15
共有 6 个文件被更改,包括 117 次插入91 次删除
  1. 2
      com.unity.ml-agents/CHANGELOG.md
  2. 116
      ml-agents/mlagents/trainers/stats.py
  3. 3
      ml-agents/mlagents/trainers/tests/test_simple_rl.py
  4. 26
      ml-agents/mlagents/trainers/tests/test_stats.py
  5. 58
      ml-agents/mlagents/trainers/trainer/trainer.py
  6. 3
      ml-agents/mlagents/trainers/trainer_controller.py

2
com.unity.ml-agents/CHANGELOG.md


### Major Changes
### Minor Changes
- Format of console output has changed slightly and now matches the name of the model/summary directory. (#3630, #3616)
## [0.15.0-preview] - 2020-03-18
### Major Changes

116
ml-agents/mlagents/trainers/stats.py


from collections import defaultdict
from typing import List, Dict, NamedTuple
from enum import Enum
from typing import List, Dict, NamedTuple, Any
import numpy as np
import abc
import csv

from mlagents.tf_utils import tf
from mlagents.tf_utils import tf, generate_session_config
from mlagents_envs.timers import set_gauge
logger = logging.getLogger("mlagents.trainers")

return StatsSummary(0.0, 0.0, 0)
class StatsPropertyType(Enum):
HYPERPARAMETERS = "hyperparameters"
class StatsWriter(abc.ABC):
"""
A StatsWriter abstract class. A StatsWriter takes in a category, key, scalar value, and step

) -> None:
pass
@abc.abstractmethod
def write_text(self, category: str, text: str, step: int) -> None:
def add_property(
self, category: str, property_type: StatsPropertyType, value: Any
) -> None:
"""
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters,
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
with all types of properties. For instance, a TB writer doesn't need a max step, nor should
we write hyperparameters to the CSV.
:param category: The category that the property belongs to.
:param type: The type of property.
:param value: The property itself.
"""
pass

) -> None:
for val, stats_summary in values.items():
set_gauge(f"{category}.{val}.mean", float(stats_summary.mean))
def write_text(self, category: str, text: str, step: int) -> None:
pass
class ConsoleWriter(StatsWriter):

)
)
def write_text(self, category: str, text: str, step: int) -> None:
pass
def add_property(
self, category: str, property_type: StatsPropertyType, value: Any
) -> None:
if property_type == StatsPropertyType.HYPERPARAMETERS:
logger.info(
"""Hyperparameters for behavior name {0}: \n{1}""".format(
category, self._dict_to_str(value, 0)
)
)
def _dict_to_str(self, param_dict: Dict[str, Any], num_tabs: int) -> str:
"""
Takes a parameter dictionary and converts it to a human-readable string.
Recurses if there are multiple levels of dict. Used to print out hyperparameters.
param: param_dict: A Dictionary of key, value parameters.
return: A string version of this dictionary.
"""
if not isinstance(param_dict, dict):
return str(param_dict)
else:
append_newline = "\n" if num_tabs > 0 else ""
return append_newline + "\n".join(
[
"\t"
+ " " * num_tabs
+ "{0}:\t{1}".format(
x, self._dict_to_str(param_dict[x], num_tabs + 1)
)
for x in param_dict
]
)
class TensorboardWriter(StatsWriter):

os.makedirs(filewriter_dir, exist_ok=True)
self.summary_writers[category] = tf.summary.FileWriter(filewriter_dir)
def write_text(self, category: str, text: str, step: int) -> None:
self._maybe_create_summary_writer(category)
self.summary_writers[category].add_summary(text, step)
def add_property(
self, category: str, property_type: StatsPropertyType, value: Any
) -> None:
if property_type == StatsPropertyType.HYPERPARAMETERS:
assert isinstance(value, dict)
text = self._dict_to_tensorboard("Hyperparameters", value)
self._maybe_create_summary_writer(category)
self.summary_writers[category].add_summary(text, 0)
def _dict_to_tensorboard(self, name: str, input_dict: Dict[str, Any]) -> str:
"""
Convert a dict to a Tensorboard-encoded string.
:param name: The name of the text.
:param input_dict: A dictionary that will be displayed in a table on Tensorboard.
"""
try:
with tf.Session(config=generate_session_config()) as sess:
s_op = tf.summary.text(
name,
tf.convert_to_tensor(
([[str(x), str(input_dict[x])] for x in input_dict])
),
)
s = sess.run(s_op)
return s
except Exception:
logger.warning("Could not write text summary for Tensorboard.")
return ""
class CSVWriter(StatsWriter):

file_dir = os.path.join(self.base_dir, category + ".csv")
return file_dir
def write_text(self, category: str, text: str, step: int) -> None:
pass
def __init__(self, category):
def __init__(self, category: str):
"""
Generic StatsReporter. A category is the broadest type of storage (would
correspond the run name and trainer name, e.g. 3DBalltest_3DBall. A key is the

def add_writer(writer: StatsWriter) -> None:
StatsReporter.writers.append(writer)
def add_property(self, property_type: StatsPropertyType, value: Any) -> None:
"""
Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters,
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
with all types of properties. For instance, a TB writer doesn't need a max step, nor should
we write hyperparameters to the CSV.
:param key: The type of property.
:param value: The property itself.
"""
for writer in StatsReporter.writers:
writer.add_property(self.category, property_type, value)
def add_stat(self, key: str, value: float) -> None:
"""
Add a float value stat to the StatsReporter.

for writer in StatsReporter.writers:
writer.write_stats(self.category, values, step)
del StatsReporter.stats_dict[self.category]
def write_text(self, text: str, step: int) -> None:
"""
Write out some text.
:param text: The text to write out.
:param step: Training step which to write these stats as.
"""
for writer in StatsReporter.writers:
writer.write_text(self.category, text, step)
def get_stats_summaries(self, key: str) -> StatsSummary:
"""

3
ml-agents/mlagents/trainers/tests/test_simple_rl.py


print(step, val, stats_summary.mean)
self._last_reward_summary[category] = stats_summary.mean
def write_text(self, category: str, text: str, step: int) -> None:
pass
def _check_environment_trains(
env,

26
ml-agents/mlagents/trainers/tests/test_stats.py


StatsSummary,
GaugeWriter,
ConsoleWriter,
StatsPropertyType,
)

)
def test_stat_reporter_text():
def test_stat_reporter_property():
# Test add_writer
mock_writer = mock.Mock()
StatsReporter.writers.clear()

statsreporter1 = StatsReporter("category1")
# Test write_text
step = 10
statsreporter1.write_text("this is a text", step)
mock_writer.write_text.assert_called_once_with("category1", "this is a text", step)
# Test add_property
statsreporter1.add_property("key", "this is a text")
mock_writer.add_property.assert_called_once_with(
"category1", "key", "this is a text"
)
@mock.patch("mlagents.tf_utils.tf.Summary")

)
mock_filewriter.return_value.flush.assert_called_once()
# Test hyperparameter writing - no good way to parse the TB string though.
tb_writer.add_property(
"category1", StatsPropertyType.HYPERPARAMETERS, {"example": 1.0}
)
assert mock_filewriter.return_value.add_summary.call_count > 1
def test_csv_writer():
# Test write_stats

},
10,
)
# Test hyperparameter writing - no good way to parse the TB string though.
console_writer.add_property(
"category1", StatsPropertyType.HYPERPARAMETERS, {"example": 1.0}
)
self.assertIn("Hyperparameters for behavior name", cm.output[2])
self.assertIn("example:\t1.0", cm.output[2])

58
ml-agents/mlagents/trainers/trainer/trainer.py


import time
import abc
from mlagents.tf_utils import tf
from mlagents import tf_utils
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.stats import StatsReporter, StatsPropertyType
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.brain import BrainParameters

self.training_start_time = time.time()
self.summary_freq = self.trainer_parameters["summary_freq"]
self.next_summary_step = self.summary_freq
self.stats_reporter.add_property(
StatsPropertyType.HYPERPARAMETERS, self.trainer_parameters
)
def _check_param_keys(self):
for k in self.param_keys:

"brain {2}.".format(k, self.__class__, self.brain_name)
)
def write_tensorboard_text(self, key: str, input_dict: Dict[str, Any]) -> None:
"""
Saves text to Tensorboard.
Note: Only works on tensorflow r1.2 or above.
:param key: The name of the text.
:param input_dict: A dictionary that will be displayed in a table on Tensorboard.
"""
try:
with tf.Session(config=tf_utils.generate_session_config()) as sess:
s_op = tf.summary.text(
key,
tf.convert_to_tensor(
([[str(x), str(input_dict[x])] for x in input_dict])
),
)
s = sess.run(s_op)
self.stats_reporter.write_text(s, self.get_step)
except Exception:
logger.info("Could not write text summary for Tensorboard.")
pass
def _dict_to_str(self, param_dict: Dict[str, Any], num_tabs: int) -> str:
"""
Takes a parameter dictionary and converts it to a human-readable string.
Recurses if there are multiple levels of dict. Used to print out hyperaparameters.
param: param_dict: A Dictionary of key, value parameters.
return: A string version of this dictionary.
"""
if not isinstance(param_dict, dict):
return str(param_dict)
else:
append_newline = "\n" if num_tabs > 0 else ""
return append_newline + "\n".join(
[
"\t"
+ " " * num_tabs
+ "{0}:\t{1}".format(
x, self._dict_to_str(param_dict[x], num_tabs + 1)
)
for x in param_dict
]
)
def __str__(self) -> str:
return """Hyperparameters for the {0} of brain {1}: \n{2}""".format(
self.__class__.__name__,
self.brain_name,
self._dict_to_str(self.trainer_parameters, 0),
)
@property
def parameters(self) -> Dict[str, Any]:

3
ml-agents/mlagents/trainers/trainer_controller.py


except KeyError:
trainer = self.trainer_factory.generate(brain_name)
self.trainers[brain_name] = trainer
self.logger.info(trainer)
if self.train_model:
trainer.write_tensorboard_text("Hyperparameters", trainer.parameters)
policy = trainer.create_policy(env_manager.external_brains[name_behavior_id])
trainer.add_policy(name_behavior_id, policy)

正在加载...
取消
保存