您最多选择25个主题
主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符
69 行
2.5 KiB
69 行
2.5 KiB
# # Unity ML-Agents Toolkit
|
|
import abc
|
|
from typing import Any, Tuple, List
|
|
|
|
|
|
class BaseModelSaver(abc.ABC):
    """
    Abstract base class for ModelSavers.

    A ModelSaver keeps track of registered modules (for example policies and
    optimizers) and is responsible for checkpointing them to disk, exporting
    them, and initializing/loading them. Concrete subclasses must implement
    every method decorated with ``abc.abstractmethod``.
    """

    def __init__(self) -> None:
        # The base class holds no state; subclasses set up their own storage.
        pass

    @abc.abstractmethod
    def register(self, module: Any) -> None:
        """
        Register a module with the ModelSaver.

        Implementations should store the module so that it is included in the
        saved files when saving a checkpoint or exporting the graph.

        :param module: the module to be registered
        """
        pass

    def _register_policy(self, policy: Any) -> None:
        """
        Helper hook for registering a policy with the ModelSaver.

        The base implementation does nothing; subclasses may override it.

        :param policy: the policy to be registered
        """
        pass

    def _register_optimizer(self, optimizer: Any) -> None:
        """
        Helper hook for registering an optimizer with the ModelSaver.

        The base implementation does nothing; subclasses may override it.

        :param optimizer: the optimizer to be registered
        """
        pass

    @abc.abstractmethod
    def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]:
        """
        Checkpoint the policy on disk.

        :param behavior_name: Behavior name of the behavior being trained.
        :param step: Step count to associate with this checkpoint (presumably
            the current training step -- confirm against implementations).
        :return: A tuple of the path to the exported file and a list of any
            auxiliary files produced alongside it. For instance, the exported
            file could be Model.onnx and the auxiliary files [Model.pt] for
            PyTorch.
        """
        pass

    @abc.abstractmethod
    def export(self, output_filepath: str, behavior_name: str) -> None:
        """
        Save the serialized model, given a path and a behavior name.

        This method saves the policy graph to the given filepath. The path
        should be provided without an extension, because several serialized
        model formats may be generated from it.

        :param output_filepath: path (without suffix) for the model file(s)
        :param behavior_name: Behavior name of the behavior being trained.
        """
        pass

    @abc.abstractmethod
    def initialize_or_load(self, policy: Any = None):
        """
        Initialize or load the registered modules.

        By default this operates on the modules registered with the ModelSaver.
        If ``policy`` is given, the initialization/loading is performed on that
        policy instead. This is mainly used to initialize the ghost trainer's
        fixed policy.

        :param policy: optional policy to initialize/load instead of the
            registered one. Defaults to None, meaning the registered policy
            is used.
        """
        pass
|