Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多可选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。
 
 
 
 
 

165 行
6.9 KiB

import os
import shutil
from typing import Optional, Union, cast
from mlagents_envs.exception import UnityPolicyException
from mlagents_envs.logging_util import get_logger
from mlagents.tf_utils import tf
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.tf.model_serialization import export_policy_model
from mlagents.trainers.settings import TrainerSettings, SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers import __version__
# Module-level logger used for the load/save/copy messages emitted by TFSaver.
logger = get_logger(__name__)
class TFSaver(BaseSaver):
    """
    Saver class for TensorFlow.

    Saves checkpoints for, restores, and exports the graph and session of a
    single registered TFPolicy.
    """

    def __init__(
        self, trainer_settings: TrainerSettings, model_path: str, load: bool = False
    ):
        """
        Store the paths and options used by later save/load calls.

        :param trainer_settings: Settings object; its ``init_path`` and
            ``keep_checkpoints`` attributes are read here.
        :param model_path: Directory that checkpoints are written to, and that
            is loaded from when ``load`` is True and no init path is set.
        :param load: Whether an existing model should be loaded. When False,
            the step count is reset to 0 after loading from an init path.
        """
        super().__init__()
        self.model_path = model_path
        self.initialize_path = trainer_settings.init_path
        self._keep_checkpoints = trainer_settings.keep_checkpoints
        self.load = load

        # Currently only support saving one policy. This is the one to be saved.
        self.policy: Optional[TFPolicy] = None
        self.graph = None
        self.sess = None
        self.tf_saver = None

    def register(self, module: Union[TFPolicy, TFOptimizer]) -> None:
        """
        Register a policy or an optimizer with this saver.

        :param module: The TFPolicy or TFOptimizer to register.
        :raises UnityPolicyException: If ``module`` is neither of those types.
        """
        if isinstance(module, TFPolicy):
            self._register_policy(module)
        elif isinstance(module, TFOptimizer):
            self._register_optimizer(module)
        else:
            raise UnityPolicyException(
                f"Registering Object of unsupported type {type(module)} to Saver "
            )

    def _register_policy(self, policy: TFPolicy) -> None:
        """
        Adopt ``policy`` as the policy to save, if none has been adopted yet.

        Only the first registered policy is kept; later calls are ignored.

        :param policy: The policy whose graph and session will be saved.
        """
        if self.policy is not None:
            return
        self.policy = policy
        self.graph = policy.graph
        self.sess = policy.sess
        # The Saver must be created inside the policy's graph.
        with policy.graph.as_default():
            self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)

    def save_checkpoint(self, brain_name: str, step: int) -> str:
        """
        Save a checkpoint for the current step and export the policy model.

        :param brain_name: Behavior name, used as the checkpoint file prefix.
        :param step: Training step, appended to the checkpoint name.
        :return: The checkpoint path ``<model_path>/<brain_name>-<step>``,
            without a file extension.
        """
        checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")

        # Save the TF checkpoint and graph definition
        if self.graph:
            with self.graph.as_default():
                if self.tf_saver:
                    self.tf_saver.save(self.sess, f"{checkpoint_path}.ckpt")
                tf.train.write_graph(
                    self.graph, self.model_path, "raw_graph_def.pb", as_text=False
                )

        # also save the policy so we have optimized model files for each checkpoint
        self.export(checkpoint_path, brain_name)
        return checkpoint_path

    def export(self, output_filepath: str, brain_name: str) -> None:
        """
        Export the registered policy's model via ``export_policy_model``.

        :param output_filepath: Output path handed to the exporter.
        :param brain_name: Behavior name handed to the exporter.
        """
        export_policy_model(
            self.model_path, output_filepath, brain_name, self.graph, self.sess
        )

    def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
        """
        Load the policy from a checkpoint, or initialize it from scratch.

        If there is an initialize path, load from that. Else, if ``load`` is
        set, load from the model path. Otherwise initialize the policy.
        If ``load`` is True the step count is not reset to 0; otherwise it is.
        This allows a user to, e.g., resume from an initialize path.

        :param policy: Policy to initialize or load; defaults to the
            registered policy when None.
        """
        if policy is None:
            policy = self.policy
        # NOTE(review): cast() performs no runtime check; if no policy was
        # passed or registered, the calls below fail on None.
        policy = cast(TFPolicy, policy)
        reset_steps = not self.load

        if self.initialize_path is not None:
            self._load_graph(
                policy, self.initialize_path, reset_global_steps=reset_steps
            )
        elif self.load:
            self._load_graph(policy, self.model_path, reset_global_steps=reset_steps)
        else:
            policy.initialize()

    def _load_graph(
        self, policy: TFPolicy, model_path: str, reset_global_steps: bool = False
    ) -> None:
        """
        Restore the policy's variables from the latest checkpoint in a path.

        :param policy: Policy whose graph and session are restored into.
        :param model_path: Directory searched for checkpoint state.
        :param reset_global_steps: When True, set the policy step back to 0
            after loading.
        :raises UnityPolicyException: If no checkpoint state is found in
            ``model_path``, or the checkpoint does not match the graph.
        """
        with policy.graph.as_default():
            logger.info(f"Loading model from {model_path}.")
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt is None:
                raise UnityPolicyException(
                    f"The model {model_path} could not be loaded. Make "
                    "sure you specified the right "
                    "--run-id and that the previous run you are loading from had the same "
                    "behavior names."
                )
            if self.tf_saver:
                try:
                    self.tf_saver.restore(policy.sess, ckpt.model_checkpoint_path)
                except tf.errors.NotFoundError as err:
                    # Chain the TensorFlow error so the root cause stays visible.
                    raise UnityPolicyException(
                        f"The model {model_path} was found but could not be loaded. Make "
                        "sure the model is from the same version of ML-Agents, has the same behavior parameters, "
                        "and is using the same trainer configuration as the current run."
                    ) from err
            self._check_model_version(__version__)
            if reset_global_steps:
                policy.set_step(0)
                logger.info(
                    f"Starting training from step 0 and saving to {self.model_path}."
                )
            else:
                logger.info(f"Resuming training from step {policy.get_current_step()}.")

    def _check_model_version(self, version: str) -> None:
        """
        Checks whether the model being loaded was created with the same version of
        ML-Agents, and throw a warning if not so.

        Does nothing when no policy is registered or the policy has no
        version tensors.

        :param version: The current ML-Agents version string to compare with.
        """
        if self.policy is None or self.policy.version_tensors is None:
            return
        loaded_ver = tuple(
            num.eval(session=self.sess) for num in self.policy.version_tensors
        )
        if loaded_ver != TFPolicy._convert_version_string(version):
            logger.warning(
                "The model checkpoint you are loading from was saved with ML-Agents version "
                f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents "
                f"version is {version}. Model may not behave properly."
            )

    def copy_final_model(self, source_nn_path: str) -> None:
        """
        Copy the .nn file at the given source to the destination.
        Also copies the corresponding .onnx file if it exists.

        The destination is ``self.model_path`` with the matching extension
        appended. Which formats are copied is controlled by the
        ``SerializationSettings`` flags.

        :param source_nn_path: Path of the source model; its extension is
            dropped and replaced with ``.nn`` / ``.onnx``.
        """
        final_model_name = os.path.splitext(source_nn_path)[0]

        if SerializationSettings.convert_to_barracuda:
            source_path = f"{final_model_name}.nn"
            destination_path = f"{self.model_path}.nn"
            shutil.copyfile(source_path, destination_path)
            logger.info(f"Copied {source_path} to {destination_path}.")

        if SerializationSettings.convert_to_onnx:
            source_path = f"{final_model_name}.onnx"
            destination_path = f"{self.model_path}.onnx"
            try:
                shutil.copyfile(source_path, destination_path)
            except OSError as err:
                # Best effort: never fail on the .onnx copy, but say why
                # it was skipped instead of dropping the error silently.
                logger.info(
                    f"Did not copy {source_path} to {destination_path}: {err}"
                )
            else:
                logger.info(f"Copied {source_path} to {destination_path}.")