浏览代码

add Saver class (only TF working)

/develop/add-fire/ckpt-2
Ruo-Ping Dong 5 年前
当前提交
6feec58a
共有 12 个文件被更改,包括 338 次插入160 次删除
  1. 11
      ml-agents/mlagents/trainers/policy/policy.py
  2. 94
      ml-agents/mlagents/trainers/policy/tf_policy.py
  3. 58
      ml-agents/mlagents/trainers/policy/torch_policy.py
  4. 2
      ml-agents/mlagents/trainers/ppo/optimizer_tf.py
  5. 3
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  6. 7
      ml-agents/mlagents/trainers/ppo/trainer.py
  7. 8
      ml-agents/mlagents/trainers/torch/networks.py
  8. 54
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  9. 28
      ml-agents/mlagents/trainers/saver/saver.py
  10. 135
      ml-agents/mlagents/trainers/saver/tf_saver.py
  11. 98
      ml-agents/mlagents/trainers/saver/torch_saver.py

11
ml-agents/mlagents/trainers/policy/policy.py


self.vis_obs_size = sum(
1 for shape in behavior_spec.observation_shapes if len(shape) == 3
)
self.model_path = model_path
self.initialize_path = self.trainer_settings.init_path
self._keep_checkpoints = self.trainer_settings.keep_checkpoints
self.use_continuous_act = behavior_spec.is_action_continuous()
self.num_branches = self.behavior_spec.action_size
self.previous_action_dict: Dict[str, np.array] = {}

@abstractmethod
def get_current_step(self):
pass
@abstractmethod
def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
pass
@abstractmethod
def save(self, output_filepath: str, settings: SerializationSettings) -> None:
pass
@abstractmethod

94
ml-agents/mlagents/trainers/policy/tf_policy.py


from mlagents_envs.timers import timed
from mlagents.model_serialization import SerializationSettings, export_policy_model
from mlagents.tf_utils import tf
from mlagents import tf_utils
from mlagents_envs.exception import UnityException

self.sess = tf.Session(
config=tf_utils.generate_session_config(), graph=self.graph
)
self.saver: Optional[tf.Operation] = None
self._initialize_tensorflow_references()
self.grads = None
self.update_batch: Optional[tf.Operation] = None

ver = LooseVersion(version_string)
return tuple(map(int, ver.version[0:3]))
def _check_model_version(self, version: str) -> None:
"""
Checks whether the model being loaded was created with the same version of
ML-Agents, and throw a warning if not so.
"""
if self.version_tensors is not None:
loaded_ver = tuple(
num.eval(session=self.sess) for num in self.version_tensors
)
if loaded_ver != TFPolicy._convert_version_string(version):
logger.warning(
f"The model checkpoint you are loading from was saved with ML-Agents version "
f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents"
f"version is {version}. Model may not behave properly."
)
self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
with self.graph.as_default():
self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
logger.info(f"Loading model from {model_path}.")
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt is None:
raise UnityPolicyException(
"The model {} could not be loaded. Make "
"sure you specified the right "
"--run-id and that the previous run you are loading from had the same "
"behavior names.".format(model_path)
)
try:
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
except tf.errors.NotFoundError:
raise UnityPolicyException(
"The model {} was found but could not be loaded. Make "
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
"and is using the same trainer configuration as the current run.".format(
model_path
)
)
self._check_model_version(__version__)
if reset_global_steps:
self._set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path
)
)
else:
logger.info(f"Resuming training from step {self.get_current_step()}.")
def initialize_or_load(self):
# If there is an initialize path, load from that. Else, load from the set model path.
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
# e.g., resume from an initialize path.
reset_steps = not self.load
if self.initialize_path is not None:
self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
elif self.load:
self._load_graph(self.model_path, reset_global_steps=reset_steps)
else:
self._initialize_graph()
def get_weights(self):
with self.graph.as_default():
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

:return:list of update var names
"""
return list(self.update_dict.keys())
def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param settings: SerializationSettings for exporting the model.
"""
# Save the TF checkpoint and graph definition
with self.graph.as_default():
if self.saver:
self.saver.save(self.sess, f"{checkpoint_path}.ckpt")
tf.train.write_graph(
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
# also save the policy so we have optimized model files for each checkpoint
self.save(checkpoint_path, settings)
def save(self, output_filepath: str, settings: SerializationSettings) -> None:
"""
Saves the serialized model, given a path and SerializationSettings
This method will save the policy graph to the given filepath. The path
should be provided without an extension as multiple serialized model formats
may be generated as a result.
:param output_filepath: path (without suffix) for the model file(s)
:param settings: SerializationSettings for how to save the model.
"""
export_policy_model(output_filepath, settings, self.graph, self.sess)
def update_normalization(self, vector_obs: np.ndarray) -> None:
"""

58
ml-agents/mlagents/trainers/policy/torch_policy.py


import torch
import os
from torch import onnx
from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.behavior_id_utils import get_global_agent_id

from mlagents.trainers.settings import TrainerSettings, TestingConfiguration
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.networks import ActorCritic
from mlagents.trainers.torch.networks import ActorCritic, GlobalSteps
EPSILON = 1e-7 # Small value to avoid divide by zero

reparameterize,
condition_sigma_on_obs,
)
self.global_step = 0
self.global_step = GlobalSteps() # could be much simpler if TorchPolicy is nn.Module
self.grads = None
if TestingConfiguration.device != "cpu":
torch.set_default_tensor_type(torch.cuda.FloatTensor)

agent_ids=list(decision_requests.agent_id),
)
def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param settings: SerializationSettings for exporting the model.
"""
if not os.path.exists(self.model_path):
os.makedirs(self.model_path)
torch.save(self.actor_critic.state_dict(), f"{checkpoint_path}.pt")
def save(self, output_filepath: str, settings: SerializationSettings) -> None:
self.export_model(self.global_step)
def load_model(self, step=0): # TODO: this doesn't work
load_path = self.model_path + "/model-" + str(step) + ".pt"
self.actor_critic.load_state_dict(torch.load(load_path))
def export_model(self, step=0):
fake_vec_obs = [torch.zeros([1] + [self.vec_obs_size])]
fake_vis_obs = [torch.zeros([1] + [84, 84, 3])]
fake_masks = torch.ones([1] + self.actor_critic.act_size)
# fake_memories = torch.zeros([1] + [self.m_size])
export_path = "./model-" + str(step) + ".onnx"
output_names = ["action", "action_probs"]
input_names = ["vector_observation", "action_mask"]
dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]}
onnx.export(
self.actor_critic,
(fake_vec_obs, fake_vis_obs, fake_masks),
export_path,
verbose=True,
opset_version=12,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
@property
def use_vis_obs(self):
return self.vis_obs_size > 0

Gets current model step.
:return: current model step.
"""
step = self.global_step
step = self.global_step.get_step()
def _set_step(self, step: int) -> int:
"""
Sets current model step to step without creating additional ops.
:param step: Step to set the current model step to.
:return: The step the model was set to.
"""
self.global_step.set_step(step)
self.global_step += n_steps
self.global_step.increment(n_steps)
return self.get_current_step()
def load_weights(self, values: List[np.ndarray]) -> None:

def get_weights(self) -> List[np.ndarray]:
return []
def get_modules(self):
return {'Policy': self.actor_critic, 'global_step': self.global_step}

2
ml-agents/mlagents/trainers/ppo/optimizer_tf.py


}
)
self.policy.initialize_or_load()
def _create_cc_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:

3
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


}
return update_stats
def get_modules(self):
return {'Optimizer': self.optimizer}

7
ml-agents/mlagents/trainers/ppo/trainer.py


) # type: ignore
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
if self.saver is None:
self.saver = self.create_saver(policy=policy)
self.saver.register(self.policy)
self.saver.register(self.optimizer)
self.saver.maybe_load()
# Needed to resume loads properly
self.step = policy.get_current_step()

8
ml-agents/mlagents/trainers/torch/networks.py


class GlobalSteps(nn.Module):
def __init__(self):
super().__init__()
self.global_step = torch.Tensor([0])
self.global_step = nn.Parameter(torch.Tensor([0]), requires_grad=False)
def set_step(self, value):
self.global_step[:] = value
def get_step(self):
return int(self.global_step.item())
class LearningRate(nn.Module):

54
ml-agents/mlagents/trainers/trainer/rl_trainer.py


from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.settings import TestingConfiguration
from mlagents.trainers.stats import StatsPropertyType
from mlagents.trainers.saver.saver import Saver
from mlagents.trainers.saver.torch_saver import TorchSaver
from mlagents.trainers.saver.tf_saver import TFSaver
RewardSignalResults = Dict[str, RewardSignalResult]

self.trainer_settings.max_steps = TestingConfiguration.max_steps
self._next_save_step = 0
self._next_summary_step = 0
self.saver = None
def end_episode(self) -> None:
"""

Create a Policy object that uses the TensorFlow backend.
"""
pass
def create_saver(self, policy: Policy) -> Saver:
if self.framework == "torch":
return self.create_torch_saver(policy)
else:
return self.create_tf_saver(policy)
def create_torch_saver(self, policy: TorchPolicy) -> TorchSaver:
"""
Create a Saver object that uses the PyTorch backend.
"""
saver = TorchSaver(
policy,
self.trainer_settings,
model_path=self.artifact_path,
load=self.load,
)
return saver
def create_tf_saver(self, policy: TFPolicy) -> TFSaver:
"""
Create a Saver object that uses the TensorFlow backend.
"""
saver = TFSaver(
policy,
self.trainer_settings,
model_path=self.artifact_path,
load=self.load,
)
return saver
def _policy_mean_reward(self) -> Optional[float]:
""" Returns the mean episode reward for the current policy. """

logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
model_path = policy.model_path
settings = SerializationSettings(model_path, self.brain_name)
checkpoint_path = os.path.join(model_path, f"{self.brain_name}-{self.step}")
policy.checkpoint(checkpoint_path, settings)
# policy = list(self.policies.values())[0]
# model_path = policy.model_path
settings = SerializationSettings(self.saver.model_path, self.brain_name)
checkpoint_path = os.path.join(self.saver.model_path, f"{self.brain_name}-{self.step}")
# policy.checkpoint(checkpoint_path, settings)
self.saver.save_checkpoint(checkpoint_path, settings)
new_checkpoint = NNCheckpoint(
int(self.step),
f"{checkpoint_path}.nn",

logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
settings = SerializationSettings(policy.model_path, self.brain_name)
# policy = list(self.policies.values())[0]
settings = SerializationSettings(self.saver.model_path, self.brain_name)
model_checkpoint, file_path=f"{policy.model_path}.nn"
model_checkpoint, file_path=f"{self.saver.model_path}.nn"
policy.save(policy.model_path, settings)
# policy.save(policy.model_path, settings)
self.saver.export(self.saver.model_path, settings)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
@abc.abstractmethod

28
ml-agents/mlagents/trainers/saver/saver.py


# # Unity ML-Agents Toolkit
import abc
class Saver(abc.ABC):
"""This class is the base class for the Saver"""
def __init__(self):
"""
TBA
"""
pass
@abc.abstractmethod
def register(self):
pass
@abc.abstractmethod
def save_checkpoint(self):
pass
@abc.abstractmethod
def maybe_load(self):
pass
@abc.abstractmethod
def export(self):
pass

135
ml-agents/mlagents/trainers/saver/tf_saver.py


from typing import Tuple
from distutils.version import LooseVersion
from mlagents_envs.exception import UnityException
from mlagents_envs.logging_util import get_logger
from mlagents.tf_utils import tf
from mlagents.trainers.saver.saver import Saver
from mlagents.model_serialization import SerializationSettings, export_policy_model
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers import __version__
logger = get_logger(__name__)
class TFSaver(Saver):
"""
Saver class for TensorFlow
"""
def __init__(
self,
policy: TFPolicy,
trainer_settings: TrainerSettings,
model_path: str,
load: bool = False,
):
super().__init__()
self.policy = policy
self.model_path = model_path
self.initialize_path = trainer_settings.init_path
self._keep_checkpoints = trainer_settings.keep_checkpoints
self.load = load
self.graph = self.policy.graph
self.sess = self.policy.sess
with self.graph.as_default():
self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
def register(self, module_dict):
pass
def save_checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param settings: SerializationSettings for exporting the model.
"""
print('save checkpoint_path:', checkpoint_path)
# Save the TF checkpoint and graph definition
with self.graph.as_default():
if self.saver:
self.saver.save(self.sess, f"{checkpoint_path}.ckpt")
tf.train.write_graph(
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
# also save the policy so we have optimized model files for each checkpoint
self.export(checkpoint_path, settings)
def export(self, output_filepath: str, settings: SerializationSettings) -> None:
"""
Saves the serialized model, given a path and SerializationSettings
This method will save the policy graph to the given filepath. The path
should be provided without an extension as multiple serialized model formats
may be generated as a result.
:param output_filepath: path (without suffix) for the model file(s)
:param settings: SerializationSettings for how to save the model.
"""
print('export output_filepath:', output_filepath)
export_policy_model(output_filepath, settings, self.graph, self.sess)
def maybe_load(self):
# If there is an initialize path, load from that. Else, load from the set model path.
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
# e.g., resume from an initialize path.
reset_steps = not self.load
if self.initialize_path is not None:
self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
elif self.load:
self._load_graph(self.model_path, reset_global_steps=reset_steps)
else:
self.policy._initialize_graph()
def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
print('load model_path:', model_path)
with self.graph.as_default():
logger.info(f"Loading model from {model_path}.")
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt is None:
raise UnityPolicyException(
"The model {} could not be loaded. Make "
"sure you specified the right "
"--run-id and that the previous run you are loading from had the same "
"behavior names.".format(model_path)
)
try:
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
except tf.errors.NotFoundError:
raise UnityPolicyException(
"The model {} was found but could not be loaded. Make "
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
"and is using the same trainer configuration as the current run.".format(
model_path
)
)
self._check_model_version(__version__)
if reset_global_steps:
self.policy._set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path
)
)
else:
logger.info(f"Resuming training from step {self.policy.get_current_step()}.")
def _check_model_version(self, version: str) -> None:
"""
Checks whether the model being loaded was created with the same version of
ML-Agents, and throw a warning if not so.
"""
if self.policy.version_tensors is not None:
loaded_ver = tuple(
num.eval(session=self.sess) for num in self.policy.version_tensors
)
if loaded_ver != TFPolicy._convert_version_string(version):
logger.warning(
f"The model checkpoint you are loading from was saved with ML-Agents version "
f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents"
f"version is {version}. Model may not behave properly."
)

98
ml-agents/mlagents/trainers/saver/torch_saver.py


import os
import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.saver.saver import Saver
from mlagents.model_serialization import SerializationSettings
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.torch_policy import TorchPolicy
logger = get_logger(__name__)
class TorchSaver(Saver):
"""
Saver class for PyTorch
"""
def __init__(
self,
policy: TorchPolicy,
trainer_settings: TrainerSettings,
model_path: str,
load: bool = False,
):
super().__init__()
self.policy = policy
self.model_path = model_path
self.initialize_path = trainer_settings.init_path
self._keep_checkpoints = trainer_settings.keep_checkpoints
self.load = load
self.modules = {}
def register(self, module):
self.modules.update(module.get_modules())
def save_checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param settings: SerializationSettings for exporting the model.
"""
print('save checkpoint_path:', checkpoint_path)
if not os.path.exists(self.model_path):
os.makedirs(self.model_path)
state_dict = {name: module.state_dict() for name, module in self.modules.items()}
torch.save(state_dict, f"{checkpoint_path}.pt")
torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
def maybe_load(self):
# If there is an initialize path, load from that. Else, load from the set model path.
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
# e.g., resume from an initialize path.
reset_steps = not self.load
if self.initialize_path is not None:
self._load_model(self.initialize_path, reset_global_steps=reset_steps)
elif self.load:
self._load_model(self.model_path, reset_global_steps=reset_steps)
def export(self, output_filepath: str, settings: SerializationSettings) -> None:
print('export output_filepath:', output_filepath)
fake_vec_obs = [torch.zeros([1] + [self.policy.vec_obs_size])]
fake_vis_obs = [torch.zeros([1] + [84, 84, 3])]
fake_masks = torch.ones([1] + self.policy.actor_critic.act_size)
# print(fake_vec_obs[0].shape, fake_vis_obs[0].shape, fake_masks.shape)
# fake_memories = torch.zeros([1] + [self.m_size])
output_names = ["action", "action_probs", "is_continuous_control", \
"version_number", "memory_size", "action_output_shape"]
input_names = ["vector_observation", "action_mask"]
dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]}
torch.onnx.export(
self.policy.actor_critic,
(fake_vec_obs, fake_vis_obs, fake_masks),
f"{output_filepath}.onnx",
verbose=False,
opset_version=9,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
def _load_model(self, load_path: str, reset_global_steps: bool = False) -> None:
model_path = os.path.join(load_path, "checkpoint.pt")
print('load model_path:', model_path)
saved_state_dict = torch.load(model_path)
for name, state_dict in saved_state_dict.items():
self.modules[name].load_state_dict(state_dict)
if reset_global_steps:
self.policy._set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path
)
)
else:
logger.info(f"Resuming training from step {self.policy.get_current_step()}.")
正在加载...
取消
保存