Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择25个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符。
 
 
 
 
 

98 行
3.7 KiB

# Unity ML-Agents Toolkit
from typing import Dict, Any, Optional, List
import os
import attr
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
from mlagents_envs.logging_util import get_logger
# Module-level logger; used to report checkpoint file removal results.
logger = get_logger(__name__)
@attr.s(auto_attribs=True)
class ModelCheckpoint:
    """
    Record describing one saved model checkpoint.
    Instances are converted to plain dicts with attr.asdict before being
    stored in the training status.
    """

    # Step count associated with this checkpoint
    # (presumably training steps -- confirm against the code that creates it).
    steps: int
    # Path of the saved model file; this is the file deleted when the
    # checkpoint record is removed.
    file_path: str
    # Reward associated with this checkpoint; may be None.
    reward: Optional[float]
    # Creation timestamp (presumably seconds since the epoch -- confirm
    # against the code that creates it).
    creation_time: float
class ModelCheckpointManager:
    """
    Keeps per-behavior checkpoint records in GlobalTrainingStatus.
    Checkpoint records are plain dicts (produced with attr.asdict) stored
    under StatusType.CHECKPOINTS; the final model record is stored
    separately under StatusType.FINAL_CHECKPOINT.
    """

    @staticmethod
    def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
        """
        Returns the checkpoint records stored for a behavior.
        If nothing (or an empty value) is stored yet, a new empty list is
        registered under StatusType.CHECKPOINTS and returned instead.
        :param behavior_name: Behavior name whose checkpoints are requested.
        :return: The stored checkpoint records, or the newly registered
            empty list.
        """
        checkpoint_list = GlobalTrainingStatus.get_parameter_state(
            behavior_name, StatusType.CHECKPOINTS
        )
        if not checkpoint_list:
            checkpoint_list = []
            GlobalTrainingStatus.set_parameter_state(
                behavior_name, StatusType.CHECKPOINTS, checkpoint_list
            )
        return checkpoint_list

    @staticmethod
    def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
        """
        Removes the model file referenced by a checkpoint record.
        If the file cannot be found, this is logged and nothing else is done.
        :param checkpoint: A checkpoint stored in checkpoint_list; its
            "file_path" entry names the file to delete.
        """
        file_path: str = checkpoint["file_path"]
        # Attempt the delete directly instead of checking for existence
        # first, so a file that vanishes in between is not an error.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            logger.info(f"Checkpoint at {file_path} could not be found.")
        else:
            logger.info(f"Removed checkpoint model {file_path}.")

    @classmethod
    def _cleanup_extra_checkpoints(
        cls, checkpoints: List[Dict], keep_checkpoints: int
    ) -> List[Dict]:
        """
        Ensures that the number of checkpoints stored is within the number
        of checkpoints the user defines. While the list is longer than the
        limit, records are taken from the front of the list and their model
        files are removed.
        :param checkpoints: The checkpoint records to manage; trimmed in place.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
            A value of 0 or less disables the cleanup, so nothing is removed.
        :return: The same list object that was passed in.
        """
        if keep_checkpoints <= 0:
            return checkpoints
        while len(checkpoints) > keep_checkpoints:
            cls.remove_checkpoint(checkpoints.pop(0))
        return checkpoints

    @classmethod
    def add_checkpoint(
        cls,
        behavior_name: str,
        new_checkpoint: "ModelCheckpoint",
        keep_checkpoints: int,
    ) -> None:
        """
        Inserts new checkpoint information, then removes the records at the
        front of the list if the limit is exceeded.
        :param behavior_name: Behavior name for the checkpoint.
        :param new_checkpoint: The new checkpoint to be recorded.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        """
        new_checkpoint_dict = attr.asdict(new_checkpoint)
        checkpoints = cls.get_checkpoints(behavior_name)
        checkpoints.append(new_checkpoint_dict)
        # The list is trimmed in place, so the return value is not needed.
        cls._cleanup_extra_checkpoints(checkpoints, keep_checkpoints)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.CHECKPOINTS, checkpoints
        )

    @classmethod
    def track_final_checkpoint(
        cls, behavior_name: str, final_checkpoint: "ModelCheckpoint"
    ) -> None:
        """
        Stores the information about the final model (or intermediate model
        if training is interrupted) under StatusType.FINAL_CHECKPOINT.
        The list of regular checkpoints is not modified here.
        :param behavior_name: Behavior name of the model.
        :param final_checkpoint: Checkpoint information for the final model.
        """
        final_model_dict = attr.asdict(final_checkpoint)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.FINAL_CHECKPOINT, final_model_dict
        )