import os

import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.saver.saver import Saver
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.torch.model_serialization import ModelSerializer

logger = get_logger(__name__)


class TorchSaver(Saver):
    """
    Saver class for PyTorch
    """

    def __init__(
        self,
        policy: TorchPolicy,
        trainer_settings: TrainerSettings,
        model_path: str,
        load: bool = False,
    ):
        super().__init__()
        self.policy = policy
        self.model_path = model_path
        self.initialize_path = trainer_settings.init_path
        self._keep_checkpoints = trainer_settings.keep_checkpoints
        self.load = load
        self.exporter = ModelSerializer(self.policy)

        self.modules = {}

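    def register(self, module):
        """
        Adds the modules returned by module.get_modules() to the set of
        modules that are saved to and restored from checkpoints.
        """
        self.modules.update(module.get_modules())
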
    def save_checkpoint(self, checkpoint_path: str, brain_name: str) -> None:
        """
        Checkpoints the policy on disk.

        :param checkpoint_path: filepath to write the checkpoint
        :param brain_name: Brain name of brain to be trained
        """
        print('save checkpoint_path:', checkpoint_path)
        os.makedirs(self.model_path, exist_ok=True)
        state_dict = {name: module.state_dict() for name, module in self.modules.items()}
        torch.save(state_dict, f"{checkpoint_path}.pt")
        torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
        self.export(checkpoint_path, brain_name)

    def maybe_load(self):
        # If there is an initialize path, load from that. Else, load from the set model path.
        # If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
        # e.g., resume from an initialize path.
        reset_steps = not self.load
        if self.initialize_path is not None:
            self._load_model(self.initialize_path, reset_global_steps=reset_steps)
        elif self.load:
            self._load_model(self.model_path, reset_global_steps=reset_steps)

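    def export(self, output_filepath: str, brain_name: str) -> None:
        """
        Exports the policy model through the ModelSerializer.

        :param output_filepath: filepath passed to the exporter
        :param brain_name: Brain name (not used by this method)
        """
        print('export output_filepath:', output_filepath)
        self.exporter.export_policy_model(output_filepath)
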
    def _load_model(self, load_path: str, reset_global_steps: bool = False) -> None:
        """
        Loads the state dict of every saved module from checkpoint.pt.

        :param load_path: directory that contains checkpoint.pt
        :param reset_global_steps: if True, set the policy step back to 0
        """
        model_path = os.path.join(load_path, "checkpoint.pt")
        print('load model_path:', model_path)
        saved_state_dict = torch.load(model_path)
        # Each saved name must match a module added through register().
        for name, state_dict in saved_state_dict.items():
            self.modules[name].load_state_dict(state_dict)
        if reset_global_steps:
            self.policy._set_step(0)
            logger.info(
                "Starting training from step 0 and saving to {}.".format(
                    self.model_path
                )
            )
        else:
            logger.info(f"Resuming training from step {self.policy.get_current_step()}.")
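
The branching in `maybe_load` is easier to check as a small truth table. The sketch below is a standalone illustration, not part of ML-Agents: `resolve_load` is a hypothetical helper that mirrors the logic above and returns the directory that would be loaded (or `None`) together with the step-reset flag. The example paths are made up.

```python
from typing import Optional, Tuple


def resolve_load(
    initialize_path: Optional[str], model_path: str, load: bool
) -> Tuple[Optional[str], bool]:
    """Return (directory to load from, whether steps are reset to 0).

    Hypothetical helper mirroring the branches of maybe_load.
    """
    reset_steps = not load
    if initialize_path is not None:
        # An initialize path always takes priority over the model path.
        return initialize_path, reset_steps
    if load:
        return model_path, reset_steps
    # Neither an initialize path nor load: nothing is loaded,
    # so the reset flag is never acted on.
    return None, reset_steps


if __name__ == "__main__":
    # Example paths are placeholders chosen for illustration only.
    for init, load in [(None, False), (None, True), ("runs/init", False), ("runs/init", True)]:
        print(init, load, resolve_load(init, "runs/current", load))
```

Under these assumptions, passing an initialize path together with `load=True` loads from the initialize path and keeps the saved step count, which is the "resume from an initialize path" case mentioned in the comment in `maybe_load`.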