您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
135 行
5.5 KiB
135 行
5.5 KiB
from typing import Tuple
|
|
from distutils.version import LooseVersion
|
|
from mlagents_envs.exception import UnityException
|
|
from mlagents_envs.logging_util import get_logger
|
|
from mlagents.tf_utils import tf
|
|
from mlagents.trainers.saver.saver import Saver
|
|
from mlagents.trainers.tf.model_serialization import export_policy_model
|
|
from mlagents_envs.base_env import BehaviorSpec
|
|
from mlagents.trainers.settings import TrainerSettings
|
|
from mlagents.trainers.policy.tf_policy import TFPolicy
|
|
from mlagents.trainers import __version__
|
|
|
|
|
|
# Module-level logger, named after this module (mlagents logging convention).
logger = get_logger(__name__)
|
|
|
|
|
|
class TFSaver(Saver):
    """
    Saver class for TensorFlow.

    Wraps a ``tf.train.Saver`` over the policy's graph/session to checkpoint,
    export, and (re)load a single :class:`TFPolicy`.
    """

    def __init__(
        self,
        policy: TFPolicy,
        trainer_settings: TrainerSettings,
        model_path: str,
        load: bool = False,
    ):
        """
        :param policy: Policy whose graph and session are checkpointed.
        :param trainer_settings: Provides ``init_path`` (optional warm-start
            location) and ``keep_checkpoints`` (max checkpoints retained).
        :param model_path: Directory where checkpoints and graph defs are written.
        :param load: If True, resume from ``model_path`` without resetting steps.
        """
        super().__init__()
        self.policy = policy
        self.model_path = model_path
        self.initialize_path = trainer_settings.init_path
        self._keep_checkpoints = trainer_settings.keep_checkpoints
        self.load = load

        self.graph = self.policy.graph
        self.sess = self.policy.sess
        with self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)

    def register(self, module_dict):
        # TF checkpoints capture the entire graph, so individual modules do not
        # need to be registered separately; this is a required no-op override.
        pass

    def save_checkpoint(self, checkpoint_path: str, brain_name: str) -> None:
        """
        Checkpoints the policy on disk.

        :param checkpoint_path: filepath to write the checkpoint
        :param brain_name: Brain name of brain to be trained
        """
        # Was a bare debug print(); routed through the module logger instead.
        logger.debug(f"save checkpoint_path: {checkpoint_path}")
        # Save the TF checkpoint and graph definition
        with self.graph.as_default():
            if self.saver:
                self.saver.save(self.sess, f"{checkpoint_path}.ckpt")
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )
        # also save the policy so we have optimized model files for each checkpoint
        self.export(checkpoint_path, brain_name)

    def export(self, output_filepath: str, brain_name: str) -> None:
        """
        Saves the serialized model, given a path and brain name.

        This method will save the policy graph to the given filepath. The path
        should be provided without an extension as multiple serialized model formats
        may be generated as a result.

        :param output_filepath: path (without suffix) for the model file(s)
        :param brain_name: Brain name of brain to be trained.
        """
        # Was a bare debug print(); routed through the module logger instead.
        logger.debug(f"export output_filepath: {output_filepath}")
        export_policy_model(output_filepath, brain_name, self.graph, self.sess)

    def maybe_load(self):
        # If there is an initialize path, load from that. Else, load from the set model path.
        # If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
        # e.g., resume from an initialize path.
        reset_steps = not self.load
        if self.initialize_path is not None:
            self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
        elif self.load:
            self._load_graph(self.model_path, reset_global_steps=reset_steps)
        else:
            self.policy._initialize_graph()

    def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
        """
        Restores the session from a TF checkpoint directory.

        :param model_path: Directory containing the checkpoint to restore.
        :param reset_global_steps: If True, reset the policy's step counter to 0
            after loading (used when warm-starting rather than resuming).
        :raises UnityException: If no checkpoint exists at ``model_path`` or the
            checkpoint is incompatible with the current graph.
        """
        # Was a bare debug print(); routed through the module logger instead.
        logger.debug(f"load model_path: {model_path}")
        with self.graph.as_default():
            logger.info(f"Loading model from {model_path}.")
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt is None:
                # Bug fix: the original raised UnityPolicyException, which is
                # never imported here and would itself fail with NameError.
                raise UnityException(
                    "The model {} could not be loaded. Make "
                    "sure you specified the right "
                    "--run-id and that the previous run you are loading from had the same "
                    "behavior names.".format(model_path)
                )
            try:
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            except tf.errors.NotFoundError:
                # Same fix as above: use the imported UnityException.
                raise UnityException(
                    "The model {} was found but could not be loaded. Make "
                    "sure the model is from the same version of ML-Agents, has the same behavior parameters, "
                    "and is using the same trainer configuration as the current run.".format(
                        model_path
                    )
                )
            self._check_model_version(__version__)
            if reset_global_steps:
                self.policy._set_step(0)
                logger.info(
                    "Starting training from step 0 and saving to {}.".format(
                        self.model_path
                    )
                )
            else:
                logger.info(f"Resuming training from step {self.policy.get_current_step()}.")

    def _check_model_version(self, version: str) -> None:
        """
        Checks whether the model being loaded was created with the same version of
        ML-Agents, and throw a warning if not so.
        """
        # version_tensors may be None for graphs saved before versioning existed.
        if self.policy.version_tensors is not None:
            loaded_ver = tuple(
                num.eval(session=self.sess) for num in self.policy.version_tensors
            )
            if loaded_ver != TFPolicy._convert_version_string(version):
                # Bug fix: adjacent f-strings were missing a space, producing
                # "ML-Agentsversion" in the rendered warning.
                logger.warning(
                    f"The model checkpoint you are loading from was saved with ML-Agents version "
                    f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents "
                    f"version is {version}. Model may not behave properly."
                )
|