from typing import Dict, Any
from enum import Enum
from collections import defaultdict
import json
import attr
import cattr

from mlagents.tf_utils import tf
from mlagents_envs.logging_util import get_logger
from mlagents.trainers import __version__
from mlagents.trainers.exception import TrainerError

logger = get_logger(__name__)

STATUS_FORMAT_VERSION = "0.1.0"


class StatusType(Enum):
    LESSON_NUM = "lesson_num"
    STATS_METADATA = "metadata"
    CHECKPOINTS = "checkpoints"
    FINAL_CHECKPOINT = "final_checkpoint"


@attr.s(auto_attribs=True)
class StatusMetaData:
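    """
    Metadata stored alongside the training status: the status format version and the
    ML-Agents and TensorFlow versions in use, so that mismatches can be flagged on resume.
    """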
    stats_format_version: str = STATUS_FORMAT_VERSION
    mlagents_version: str = __version__
    tensorflow_version: str = tf.__version__

    def to_dict(self) -> Dict[str, str]:
        return cattr.unstructure(self)

    @staticmethod
    def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData":
        return cattr.structure(import_dict, StatusMetaData)

    def check_compatibility(self, other: "StatusMetaData") -> None:
        """
        Check compatibility with a loaded StatusMetaData and warn the user
        if the versions mismatch. This is used for resuming from old checkpoints.
        """
        # This should cover all stats version mismatches as well.
        if self.mlagents_version != other.mlagents_version:
            logger.warning(
                "Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly."
            )
        if self.tensorflow_version != other.tensorflow_version:
            logger.warning(
                "Tensorflow checkpoint was saved with a different version of Tensorflow. Model may not resume properly."
            )
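
# Note: cattr.unstructure(StatusMetaData()) yields a flat dict of strings, e.g.
# {"stats_format_version": "0.1.0", "mlagents_version": "...", "tensorflow_version": "..."},
# which save_state() stores under the "metadata" key of the status file.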


class GlobalTrainingStatus:
    """
    Contains static methods to save the global training status and load it again on resume.
    These are values that may be needed when resuming training but cannot or should not be
    captured in a model checkpoint, such as the curriculum lesson number.
    """

    saved_state: Dict[str, Dict[str, Any]] = defaultdict(lambda: {})
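    # saved_state maps a category (usually a behavior name) to a dict of saved values and is
    # what gets serialized to training_status.json. An illustrative file (names and values
    # hypothetical) might look like:
    # {
    #     "metadata": {"stats_format_version": "0.1.0", "mlagents_version": "...", ...},
    #     "SomeBehavior": {"lesson_num": 2}
    # }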

    @staticmethod
    def load_state(path: str) -> None:
        """
        Load a JSON file that contains saved state.
        :param path: Path to the JSON file containing the state.
        """
        try:
            with open(path) as f:
                loaded_dict = json.load(f)
            # Compare the metadata
            _metadata = loaded_dict[StatusType.STATS_METADATA.value]
            StatusMetaData.from_dict(_metadata).check_compatibility(StatusMetaData())
            # Update saved state.
            GlobalTrainingStatus.saved_state.update(loaded_dict)
        except FileNotFoundError:
            logger.warning(
                "Training status file not found. Not all functions will resume properly."
            )
        except KeyError:
            raise TrainerError(
                "Metadata not found, resuming from an incompatible version of ML-Agents."
            )

    @staticmethod
    def save_state(path: str) -> None:
        """
        Save a JSON file that contains saved state.
        :param path: Path to the JSON file containing the state.
        """
        GlobalTrainingStatus.saved_state[
            StatusType.STATS_METADATA.value
        ] = StatusMetaData().to_dict()
        with open(path, "w") as f:
            json.dump(GlobalTrainingStatus.saved_state, f, indent=4)

    @staticmethod
    def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
        """
        Stores an arbitrary-named parameter in the global saved state.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The parameter, e.g. lesson number.
        :param value: The value.
        """
        GlobalTrainingStatus.saved_state[category][key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        """
        Loads an arbitrary-named parameter from training_status.json.
        If not found, returns None.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The statistic, e.g. lesson number.
        :return: The value of the parameter, or None if it was not found.
        """
        return GlobalTrainingStatus.saved_state[category].get(key.value, None)
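

# Minimal usage sketch (illustrative; the run directory and behavior name below are
# hypothetical, and the directory is assumed to already exist):
if __name__ == "__main__":
    status_path = "./results/my_run/training_status.json"
    GlobalTrainingStatus.load_state(status_path)
    GlobalTrainingStatus.set_parameter_state("MyBehavior", StatusType.LESSON_NUM, 3)
    lesson = GlobalTrainingStatus.get_parameter_state("MyBehavior", StatusType.LESSON_NUM)
    print(f"Lesson number for MyBehavior: {lesson}")
    GlobalTrainingStatus.save_state(status_path)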