GitHub
5 年前
当前提交
f5435876
共有 13 个文件被更改,包括 242 次插入 和 46 次删除
-
12com.unity.ml-agents/CHANGELOG.md
-
2docs/Migrating.md
-
8docs/Training-ML-Agents.md
-
2ml-agents/README.md
-
7ml-agents/mlagents/trainers/cli_utils.py
-
22ml-agents/mlagents/trainers/learn.py
-
25ml-agents/mlagents/trainers/meta_curriculum.py
-
1ml-agents/mlagents/trainers/settings.py
-
7ml-agents/mlagents/trainers/tests/test_learn.py
-
23ml-agents/mlagents/trainers/tests/test_meta_curriculum.py
-
4ml-agents/mlagents/trainers/trainer_controller.py
-
60ml-agents/mlagents/trainers/tests/test_training_status.py
-
115ml-agents/mlagents/trainers/training_status.py
|
|||
import os |
|||
import unittest |
|||
import json |
|||
from enum import Enum |
|||
|
|||
from mlagents.trainers.training_status import ( |
|||
StatusType, |
|||
StatusMetaData, |
|||
GlobalTrainingStatus, |
|||
) |
|||
|
|||
|
|||
def test_globaltrainingstatus(tmpdir): |
|||
path_dir = os.path.join(tmpdir, "test.json") |
|||
|
|||
GlobalTrainingStatus.set_parameter_state("Category1", StatusType.LESSON_NUM, 3) |
|||
GlobalTrainingStatus.save_state(path_dir) |
|||
|
|||
with open(path_dir, "r") as fp: |
|||
test_json = json.load(fp) |
|||
|
|||
assert "Category1" in test_json |
|||
assert StatusType.LESSON_NUM.value in test_json["Category1"] |
|||
assert test_json["Category1"][StatusType.LESSON_NUM.value] == 3 |
|||
assert "metadata" in test_json |
|||
|
|||
GlobalTrainingStatus.load_state(path_dir) |
|||
restored_val = GlobalTrainingStatus.get_parameter_state( |
|||
"Category1", StatusType.LESSON_NUM |
|||
) |
|||
assert restored_val == 3 |
|||
|
|||
# Test unknown categories and status types (keys) |
|||
unknown_category = GlobalTrainingStatus.get_parameter_state( |
|||
"Category3", StatusType.LESSON_NUM |
|||
) |
|||
|
|||
class FakeStatusType(Enum): |
|||
NOTAREALKEY = "notarealkey" |
|||
|
|||
unknown_key = GlobalTrainingStatus.get_parameter_state( |
|||
"Category1", FakeStatusType.NOTAREALKEY |
|||
) |
|||
assert unknown_category is None |
|||
assert unknown_key is None |
|||
|
|||
|
|||
class StatsMetaDataTest(unittest.TestCase): |
|||
def test_metadata_compare(self): |
|||
# Test write_stats |
|||
with self.assertLogs("mlagents.trainers", level="WARNING") as cm: |
|||
default_metadata = StatusMetaData() |
|||
version_statsmetadata = StatusMetaData(mlagents_version="test") |
|||
default_metadata.check_compatibility(version_statsmetadata) |
|||
|
|||
tf_version_statsmetadata = StatusMetaData(tensorflow_version="test") |
|||
default_metadata.check_compatibility(tf_version_statsmetadata) |
|||
|
|||
# Assert that 2 warnings have been thrown |
|||
assert len(cm.output) == 2 |
|
|||
from typing import Dict, Any |
|||
from enum import Enum |
|||
from collections import defaultdict |
|||
import json |
|||
import attr |
|||
import cattr |
|||
|
|||
from mlagents.tf_utils import tf |
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents.trainers import __version__ |
|||
from mlagents.trainers.exception import TrainerError |
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
STATUS_FORMAT_VERSION = "0.1.0" |
|||
|
|||
|
|||
class StatusType(Enum): |
|||
LESSON_NUM = "lesson_num" |
|||
STATS_METADATA = "metadata" |
|||
|
|||
|
|||
@attr.s(auto_attribs=True) |
|||
class StatusMetaData: |
|||
stats_format_version: str = STATUS_FORMAT_VERSION |
|||
mlagents_version: str = __version__ |
|||
tensorflow_version: str = tf.__version__ |
|||
|
|||
def to_dict(self) -> Dict[str, str]: |
|||
return cattr.unstructure(self) |
|||
|
|||
@staticmethod |
|||
def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData": |
|||
return cattr.structure(import_dict, StatusMetaData) |
|||
|
|||
def check_compatibility(self, other: "StatusMetaData") -> None: |
|||
""" |
|||
Check compatibility with a loaded StatsMetaData and warn the user |
|||
if versions mismatch. This is used for resuming from old checkpoints. |
|||
""" |
|||
# This should cover all stats version mismatches as well. |
|||
if self.mlagents_version != other.mlagents_version: |
|||
logger.warning( |
|||
"Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly." |
|||
) |
|||
if self.tensorflow_version != other.tensorflow_version: |
|||
logger.warning( |
|||
"Tensorflow checkpoint was saved with a different version of Tensorflow. Model may not resume properly." |
|||
) |
|||
|
|||
|
|||
class GlobalTrainingStatus: |
|||
""" |
|||
GlobalTrainingStatus class that contains static methods to save global training status and |
|||
load it on a resume. These are values that might be needed for the training resume that |
|||
cannot/should not be captured in a model checkpoint, such as curriclum lesson. |
|||
""" |
|||
|
|||
saved_state: Dict[str, Dict[str, Any]] = defaultdict(lambda: {}) |
|||
|
|||
@staticmethod |
|||
def load_state(path: str) -> None: |
|||
""" |
|||
Load a JSON file that contains saved state. |
|||
:param path: Path to the JSON file containing the state. |
|||
""" |
|||
try: |
|||
with open(path, "r") as f: |
|||
loaded_dict = json.load(f) |
|||
# Compare the metadata |
|||
_metadata = loaded_dict[StatusType.STATS_METADATA.value] |
|||
StatusMetaData.from_dict(_metadata).check_compatibility(StatusMetaData()) |
|||
# Update saved state. |
|||
GlobalTrainingStatus.saved_state.update(loaded_dict) |
|||
except FileNotFoundError: |
|||
logger.warning( |
|||
"Training status file not found. Not all functions will resume properly." |
|||
) |
|||
except KeyError: |
|||
raise TrainerError( |
|||
"Metadata not found, resuming from an incompatible version of ML-Agents." |
|||
) |
|||
|
|||
@staticmethod |
|||
def save_state(path: str) -> None: |
|||
""" |
|||
Save a JSON file that contains saved state. |
|||
:param path: Path to the JSON file containing the state. |
|||
""" |
|||
GlobalTrainingStatus.saved_state[ |
|||
StatusType.STATS_METADATA.value |
|||
] = StatusMetaData().to_dict() |
|||
with open(path, "w") as f: |
|||
json.dump(GlobalTrainingStatus.saved_state, f, indent=4) |
|||
|
|||
@staticmethod |
|||
def set_parameter_state(category: str, key: StatusType, value: Any) -> None: |
|||
""" |
|||
Stores an arbitrary-named parameter in the global saved state. |
|||
:param category: The category (usually behavior name) of the parameter. |
|||
:param key: The parameter, e.g. lesson number. |
|||
:param value: The value. |
|||
""" |
|||
GlobalTrainingStatus.saved_state[category][key.value] = value |
|||
|
|||
@staticmethod |
|||
def get_parameter_state(category: str, key: StatusType) -> Any: |
|||
""" |
|||
Loads an arbitrary-named parameter from training_status.json. |
|||
If not found, returns None. |
|||
:param category: The category (usually behavior name) of the parameter. |
|||
:param key: The statistic, e.g. lesson number. |
|||
:param value: The value. |
|||
""" |
|||
return GlobalTrainingStatus.saved_state[category].get(key.value, None) |
撰写
预览
正在加载...
取消
保存
Reference in new issue