Ruo-Ping Dong
5 年前
当前提交
6feec58a
共有 12 个文件被更改,包括 338 次插入 和 160 次删除
-
11 ml-agents/mlagents/trainers/policy/policy.py
-
94 ml-agents/mlagents/trainers/policy/tf_policy.py
-
58 ml-agents/mlagents/trainers/policy/torch_policy.py
-
2 ml-agents/mlagents/trainers/ppo/optimizer_tf.py
-
3 ml-agents/mlagents/trainers/ppo/optimizer_torch.py
-
7 ml-agents/mlagents/trainers/ppo/trainer.py
-
8 ml-agents/mlagents/trainers/torch/networks.py
-
54 ml-agents/mlagents/trainers/trainer/rl_trainer.py
-
28 ml-agents/mlagents/trainers/saver/saver.py
-
135 ml-agents/mlagents/trainers/saver/tf_saver.py
-
98 ml-agents/mlagents/trainers/saver/torch_saver.py
|
|||
# # Unity ML-Agents Toolkit |
|||
import abc |
|||
|
|||
|
|||
class Saver(abc.ABC):
    """
    Abstract interface shared by the framework-specific savers.

    Concrete subclasses must implement ``register``, ``save_checkpoint``,
    ``maybe_load`` and ``export``; the class cannot be instantiated until all
    four are overridden.
    """

    def __init__(self):
        """Initialize the base saver, which keeps no state of its own."""

    @abc.abstractmethod
    def register(self):
        """Register an object with the saver; subclasses define what is tracked."""

    @abc.abstractmethod
    def save_checkpoint(self):
        """Write a checkpoint; subclasses define the format and destination."""

    @abc.abstractmethod
    def maybe_load(self):
        """Restore previously saved state if the subclass decides it should."""

    @abc.abstractmethod
    def export(self):
        """Export the model; subclasses define the serialized format."""
|
|||
from typing import Tuple |
|||
from distutils.version import LooseVersion |
|||
from mlagents_envs.exception import UnityException |
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents.tf_utils import tf |
|||
from mlagents.trainers.saver.saver import Saver |
|||
from mlagents.model_serialization import SerializationSettings, export_policy_model |
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
from mlagents.trainers.policy.tf_policy import TFPolicy |
|||
from mlagents.trainers import __version__ |
|||
|
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
class TFSaver(Saver):
    """
    Saver class for TensorFlow.

    Builds a ``tf.train.Saver`` inside the policy's graph and uses it to write
    checkpoints, export the serialized model and restore a previous run.
    """

    def __init__(
        self,
        policy: TFPolicy,
        trainer_settings: TrainerSettings,
        model_path: str,
        load: bool = False,
    ):
        """
        :param policy: TFPolicy whose ``graph`` and ``sess`` are saved/restored.
        :param trainer_settings: Provides ``init_path`` and ``keep_checkpoints``.
        :param model_path: Directory the raw graph definition is written to, and
            the directory restored from when ``load`` is True and no
            initialize path is set.
        :param load: Whether ``maybe_load`` should resume from ``model_path``.
        """
        super().__init__()
        self.policy = policy
        self.model_path = model_path
        self.initialize_path = trainer_settings.init_path
        self._keep_checkpoints = trainer_settings.keep_checkpoints
        self.load = load

        self.graph = self.policy.graph
        self.sess = self.policy.sess
        # The saver has to be created while the policy's graph is the default
        # graph so that it tracks that graph's variables.
        with self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)

    def register(self, module_dict):
        """
        No-op for TensorFlow: nothing is tracked per registered object.

        :param module_dict: Ignored.
        """
        pass

    def save_checkpoint(
        self, checkpoint_path: str, settings: SerializationSettings
    ) -> None:
        """
        Checkpoints the policy on disk.

        :param checkpoint_path: filepath to write the checkpoint
        :param settings: SerializationSettings for exporting the model.
        """
        logger.debug(f"save checkpoint_path: {checkpoint_path}")
        # Save the TF checkpoint and graph definition
        with self.graph.as_default():
            if self.saver:
                self.saver.save(self.sess, f"{checkpoint_path}.ckpt")
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )
        # also save the policy so we have optimized model files for each checkpoint
        self.export(checkpoint_path, settings)

    def export(self, output_filepath: str, settings: SerializationSettings) -> None:
        """
        Saves the serialized model, given a path and SerializationSettings

        This method will save the policy graph to the given filepath. The path
        should be provided without an extension as multiple serialized model formats
        may be generated as a result.

        :param output_filepath: path (without suffix) for the model file(s)
        :param settings: SerializationSettings for how to save the model.
        """
        logger.debug(f"export output_filepath: {output_filepath}")
        export_policy_model(output_filepath, settings, self.graph, self.sess)

    def maybe_load(self) -> None:
        """
        Restore the policy from disk, or initialize a fresh graph.

        If there is an initialize path, load from that. Else, load from the set
        model path. If load is set to True, don't reset steps to 0. Else, do.
        This allows a user to, e.g., resume from an initialize path.
        """
        reset_steps = not self.load
        if self.initialize_path is not None:
            self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
        elif self.load:
            self._load_graph(self.model_path, reset_global_steps=reset_steps)
        else:
            self.policy._initialize_graph()

    def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
        """
        Restore the session from the latest checkpoint found in ``model_path``.

        :param model_path: Directory containing the checkpoint state to restore.
        :param reset_global_steps: If True, the policy's step is set back to 0
            after restoring; otherwise training resumes from the restored step.
        :raises UnityException: If no checkpoint state is found in
            ``model_path``, or if it is found but cannot be restored.
        """
        with self.graph.as_default():
            logger.info(f"Loading model from {model_path}.")
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt is None:
                # NOTE(review): the original raised UnityPolicyException, which is
                # not imported in this module and therefore failed with NameError.
                # The imported UnityException is raised instead - confirm whether
                # the more specific type should be imported here.
                raise UnityException(
                    "The model {} could not be loaded. Make "
                    "sure you specified the right "
                    "--run-id and that the previous run you are loading from had the same "
                    "behavior names.".format(model_path)
                )
            try:
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            except tf.errors.NotFoundError as err:
                # Chain the TensorFlow error so the underlying cause stays visible.
                raise UnityException(
                    "The model {} was found but could not be loaded. Make "
                    "sure the model is from the same version of ML-Agents, has the same behavior parameters, "
                    "and is using the same trainer configuration as the current run.".format(
                        model_path
                    )
                ) from err
            self._check_model_version(__version__)
            if reset_global_steps:
                self.policy._set_step(0)
                logger.info(
                    "Starting training from step 0 and saving to {}.".format(
                        self.model_path
                    )
                )
            else:
                logger.info(
                    f"Resuming training from step {self.policy.get_current_step()}."
                )

    def _check_model_version(self, version: str) -> None:
        """
        Checks whether the model being loaded was created with the same version of
        ML-Agents, and throw a warning if not so.

        :param version: Version string of the currently running ML-Agents.
        """
        if self.policy.version_tensors is not None:
            loaded_ver = tuple(
                num.eval(session=self.sess) for num in self.policy.version_tensors
            )
            if loaded_ver != TFPolicy._convert_version_string(version):
                # A space was missing between "ML-Agents" and "version" where the
                # adjacent string literals are joined.
                logger.warning(
                    f"The model checkpoint you are loading from was saved with ML-Agents version "
                    f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents "
                    f"version is {version}. Model may not behave properly."
                )
|
|||
import os |
|||
|
|||
import torch |
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents.trainers.saver.saver import Saver |
|||
from mlagents.model_serialization import SerializationSettings |
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
|
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
class TorchSaver(Saver):
    """
    Saver class for PyTorch.

    Keeps a name -> module mapping of everything registered with it and
    saves/restores the ``state_dict`` of each registered module.
    """

    def __init__(
        self,
        policy: TorchPolicy,
        trainer_settings: TrainerSettings,
        model_path: str,
        load: bool = False,
    ):
        """
        :param policy: TorchPolicy used for ONNX export and step bookkeeping.
        :param trainer_settings: Provides ``init_path`` and ``keep_checkpoints``.
        :param model_path: Directory that receives ``checkpoint.pt``, and the
            directory restored from when ``load`` is True and no initialize
            path is set.
        :param load: Whether ``maybe_load`` should resume from ``model_path``.
        """
        super().__init__()
        self.policy = policy
        self.model_path = model_path
        self.initialize_path = trainer_settings.init_path
        self._keep_checkpoints = trainer_settings.keep_checkpoints
        self.load = load

        # Maps a module name to the registered module; filled by register().
        self.modules = {}

    def register(self, module):
        """
        Track the modules exposed by ``module`` so they are saved and restored.

        :param module: Object providing ``get_modules()``, which returns a
            mapping that is merged into ``self.modules``.
        """
        self.modules.update(module.get_modules())

    def save_checkpoint(
        self, checkpoint_path: str, settings: SerializationSettings
    ) -> None:
        """
        Checkpoints the policy on disk.

        The state dicts of all registered modules are written twice: to
        ``<checkpoint_path>.pt`` and to ``checkpoint.pt`` inside ``model_path``
        (the file ``_load_model`` reads back).

        :param checkpoint_path: filepath to write the checkpoint
        :param settings: SerializationSettings for exporting the model
            (currently unused by this method).
        """
        logger.debug(f"save checkpoint_path: {checkpoint_path}")
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.model_path, exist_ok=True)
        state_dict = {
            name: module.state_dict() for name, module in self.modules.items()
        }
        torch.save(state_dict, f"{checkpoint_path}.pt")
        torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))

    def maybe_load(self) -> None:
        """
        Restore the registered modules from disk when configured to do so.

        If there is an initialize path, load from that. Else, load from the set
        model path. If load is set to True, don't reset steps to 0. Else, do.
        This allows a user to, e.g., resume from an initialize path. When
        neither applies, nothing is loaded.
        """
        reset_steps = not self.load
        if self.initialize_path is not None:
            self._load_model(self.initialize_path, reset_global_steps=reset_steps)
        elif self.load:
            self._load_model(self.model_path, reset_global_steps=reset_steps)

    def export(self, output_filepath: str, settings: SerializationSettings) -> None:
        """
        Export the policy's ``actor_critic`` to ``<output_filepath>.onnx``.

        :param output_filepath: path (without suffix) for the model file
        :param settings: SerializationSettings (currently unused by this method).
        """
        logger.debug(f"export output_filepath: {output_filepath}")
        # Dummy inputs for tracing the network, each with a batch size of 1.
        fake_vec_obs = [torch.zeros([1, self.policy.vec_obs_size])]
        # NOTE(review): the visual observation shape is hard-coded to 84x84x3;
        # presumably it should come from the policy - confirm.
        fake_vis_obs = [torch.zeros([1] + [84, 84, 3])]
        # assumes actor_critic.act_size is a list so it can be concatenated - TODO confirm
        fake_masks = torch.ones([1] + self.policy.actor_critic.act_size)
        # TODO: memories are not exported yet
        # (was: fake_memories = torch.zeros([1] + [self.m_size])).
        output_names = [
            "action",
            "action_probs",
            "is_continuous_control",
            "version_number",
            "memory_size",
            "action_output_shape",
        ]
        # NOTE(review): three inputs are passed to the export below but only two
        # names are listed here - confirm the intended naming.
        input_names = ["vector_observation", "action_mask"]
        dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]}
        torch.onnx.export(
            self.policy.actor_critic,
            (fake_vec_obs, fake_vis_obs, fake_masks),
            f"{output_filepath}.onnx",
            verbose=False,
            opset_version=9,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    def _load_model(self, load_path: str, reset_global_steps: bool = False) -> None:
        """
        Load ``checkpoint.pt`` from ``load_path`` into the registered modules.

        :param load_path: Directory containing ``checkpoint.pt``.
        :param reset_global_steps: If True, the policy's step is set back to 0
            after loading; otherwise training resumes from the current step.
        :raises KeyError: If the checkpoint contains a module name that has not
            been registered with this saver.
        """
        model_path = os.path.join(load_path, "checkpoint.pt")
        logger.debug(f"load model_path: {model_path}")
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        saved_state_dict = torch.load(model_path)
        for name, state_dict in saved_state_dict.items():
            self.modules[name].load_state_dict(state_dict)
        if reset_global_steps:
            self.policy._set_step(0)
            logger.info(
                "Starting training from step 0 and saving to {}.".format(
                    self.model_path
                )
            )
        else:
            logger.info(
                f"Resuming training from step {self.policy.get_current_step()}."
            )
撰写
预览
正在加载...
取消
保存
Reference in new issue