Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents.

from typing import Tuple
from distutils.version import LooseVersion
from mlagents_envs.exception import UnityException
from mlagents_envs.logging_util import get_logger
from mlagents.tf_utils import tf
from mlagents.trainers.saver.saver import Saver
from mlagents.trainers.tf.model_serialization import export_policy_model
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers import __version__

logger = get_logger(__name__)


class TFSaver(Saver):
    """
    Saver class for TensorFlow
    """

    def __init__(
        self,
        policy: TFPolicy,
        trainer_settings: TrainerSettings,
        model_path: str,
        load: bool = False,
    ):
        super().__init__()
        self.policy = policy
        self.model_path = model_path
        self.initialize_path = trainer_settings.init_path
        self._keep_checkpoints = trainer_settings.keep_checkpoints
        self.load = load

        self.graph = self.policy.graph
        self.sess = self.policy.sess
        with self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)

    def register(self, module_dict):
        # Nothing to register for the TF saver; the policy graph and session
        # captured in __init__ are all that is needed.
        pass

    def save_checkpoint(self, checkpoint_path: str, brain_name: str) -> None:
        """
        Checkpoints the policy on disk.
        :param checkpoint_path: filepath to write the checkpoint
        :param brain_name: Brain name of brain to be trained
        """
        print('save checkpoint_path:', checkpoint_path)
        # Save the TF checkpoint and graph definition
        with self.graph.as_default():
            if self.saver:
                self.saver.save(self.sess, f"{checkpoint_path}.ckpt")
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )
        # also save the policy so we have optimized model files for each checkpoint
        self.export(checkpoint_path, brain_name)

    def export(self, output_filepath: str, brain_name: str) -> None:
        """
        Saves the serialized model, given a path and brain name.
        This method will save the policy graph to the given filepath. The path
        should be provided without an extension as multiple serialized model formats
        may be generated as a result.
        :param output_filepath: path (without suffix) for the model file(s)
        :param brain_name: Brain name of brain to be trained.
        """
        print('export output_filepath:', output_filepath)
        export_policy_model(output_filepath, brain_name, self.graph, self.sess)

    def maybe_load(self):
        # If there is an initialize path, load from that. Else, load from the set model path.
        # If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
        # e.g., resume from an initialize path.
        reset_steps = not self.load
        if self.initialize_path is not None:
            self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
        elif self.load:
            self._load_graph(self.model_path, reset_global_steps=reset_steps)
        else:
            self.policy._initialize_graph()

    def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
        print('load model_path:', model_path)
        with self.graph.as_default():
            logger.info(f"Loading model from {model_path}.")
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt is None:
                raise UnityException(
                    "The model {} could not be loaded. Make "
                    "sure you specified the right "
                    "--run-id and that the previous run you are loading from had the same "
                    "behavior names.".format(model_path)
                )
            try:
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            except tf.errors.NotFoundError:
                raise UnityException(
                    "The model {} was found but could not be loaded. Make "
                    "sure the model is from the same version of ML-Agents, has the same behavior parameters, "
                    "and is using the same trainer configuration as the current run.".format(
                        model_path
                    )
                )
        self._check_model_version(__version__)
        if reset_global_steps:
            self.policy._set_step(0)
            logger.info(
                "Starting training from step 0 and saving to {}.".format(
                    self.model_path
                )
            )
        else:
            logger.info(
                f"Resuming training from step {self.policy.get_current_step()}."
            )

    def _check_model_version(self, version: str) -> None:
        """
        Checks whether the model being loaded was created with the same version of
        ML-Agents, and logs a warning if not.
        """
        if self.policy.version_tensors is not None:
            loaded_ver = tuple(
                num.eval(session=self.sess) for num in self.policy.version_tensors
            )
            if loaded_ver != TFPolicy._convert_version_string(version):
                logger.warning(
                    f"The model checkpoint you are loading from was saved with ML-Agents version "
                    f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents "
                    f"version is {version}. Model may not behave properly."
                )