您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
101 行
3.9 KiB
101 行
3.9 KiB
# # Unity ML-Agents Toolkit
|
|
from typing import Dict, Any, Optional, List
|
|
import os
|
|
import attr
|
|
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
|
|
from mlagents_envs.logging_util import get_logger
|
|
|
|
# Module-level logger; used below to report checkpoint file removals.
logger = get_logger(__name__)
|
|
|
|
|
|
@attr.s(auto_attribs=True)
class ModelCheckpoint:
    """
    Record describing one saved model checkpoint.

    Instances are converted to plain dictionaries with ``attr.asdict`` before
    being stored by ``ModelCheckpointManager``.
    """

    # Step count associated with the checkpoint (presumably the training step
    # at which it was written -- confirm against the caller).
    steps: int
    # Path of the main checkpoint file; this file is deleted when the
    # checkpoint is removed.
    file_path: str
    # Reward associated with the checkpoint; may be None.
    reward: Optional[float]
    # Creation time of the checkpoint (presumably a timestamp in seconds --
    # confirm against the caller).
    creation_time: float
    # Additional files belonging to the checkpoint; they are deleted together
    # with ``file_path``. Every instance gets its own fresh list by default.
    auxillary_file_paths: List[str] = attr.ib(default=attr.Factory(list))
|
|
|
|
|
|
class ModelCheckpointManager:
    """
    Bookkeeping for the model checkpoints of each behavior.

    Checkpoint records are plain dictionaries (the ``attr.asdict`` form of a
    ``ModelCheckpoint``) kept in ``GlobalTrainingStatus`` under
    ``StatusType.CHECKPOINTS``. The final model is recorded separately under
    ``StatusType.FINAL_CHECKPOINT``.
    """

    @staticmethod
    def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
        """
        Return the checkpoint records stored for a behavior.

        If nothing (or an empty value) is stored yet, a new empty list is
        registered in the training status and returned, so that callers can
        append to the returned list.

        :param behavior_name: Behavior name whose checkpoints are requested.
        :return: The list of checkpoint dictionaries for the behavior.
        """
        checkpoint_list = GlobalTrainingStatus.get_parameter_state(
            behavior_name, StatusType.CHECKPOINTS
        )
        if not checkpoint_list:
            checkpoint_list = []
            GlobalTrainingStatus.set_parameter_state(
                behavior_name, StatusType.CHECKPOINTS, checkpoint_list
            )
        return checkpoint_list

    @staticmethod
    def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
        """
        Delete the files that belong to a checkpoint record.

        Files that do not exist are skipped; each outcome is logged at debug
        level.

        :param checkpoint: A checkpoint dictionary as stored in the checkpoint
            list. It must contain ``file_path``; ``auxillary_file_paths`` is
            optional.
        """
        file_paths: List[str] = [checkpoint["file_path"]]
        # The auxiliary list is optional on ModelCheckpoint, so a record
        # without it (or with a None value) is treated as having no extra
        # files instead of raising KeyError/TypeError.
        file_paths.extend(checkpoint.get("auxillary_file_paths") or [])
        for file_path in file_paths:
            if os.path.exists(file_path):
                os.remove(file_path)
                logger.debug(f"Removed checkpoint model {file_path}.")
            else:
                logger.debug(f"Checkpoint at {file_path} could not be found.")

    @classmethod
    def _cleanup_extra_checkpoints(
        cls, checkpoints: List[Dict], keep_checkpoints: int
    ) -> List[Dict]:
        """
        Trim the checkpoint list down to the number of checkpoints to keep.

        The entries at the front of the list are removed first, and their
        files are deleted. The list is modified in place; callers rely on
        that and may ignore the return value.

        :param checkpoints: Checkpoint dictionaries to trim (mutated in place).
        :param keep_checkpoints: Number of checkpoints to record
            (user-defined). A value of zero or less disables the cleanup.
        :return: The same list object that was passed in.
        """
        # A non-positive limit means nothing is ever removed.
        if keep_checkpoints <= 0:
            return checkpoints
        while len(checkpoints) > keep_checkpoints:
            cls.remove_checkpoint(checkpoints.pop(0))
        return checkpoints

    @classmethod
    def add_checkpoint(
        cls, behavior_name: str, new_checkpoint: ModelCheckpoint, keep_checkpoints: int
    ) -> None:
        """
        Record a new checkpoint and drop the ones exceeding the limit.

        :param behavior_name: Behavior name for the checkpoint.
        :param new_checkpoint: The new checkpoint to be recorded.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        """
        new_checkpoint_dict = attr.asdict(new_checkpoint)
        checkpoints = cls.get_checkpoints(behavior_name)
        checkpoints.append(new_checkpoint_dict)
        # Trims ``checkpoints`` in place, so the list stored below is already
        # within the limit.
        cls._cleanup_extra_checkpoints(checkpoints, keep_checkpoints)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.CHECKPOINTS, checkpoints
        )

    @classmethod
    def track_final_checkpoint(
        cls, behavior_name: str, final_checkpoint: ModelCheckpoint
    ) -> None:
        """
        Store the information about the final model (or the intermediate
        model if training is interrupted).

        The record is kept separately from the regular checkpoint list; this
        method does not remove any existing checkpoint.

        :param behavior_name: Behavior name of the model.
        :param final_checkpoint: Checkpoint information for the final model.
        """
        final_model_dict = attr.asdict(final_checkpoint)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.FINAL_CHECKPOINT, final_model_dict
        )
|