浏览代码

move tf and add torch model serialization

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
6d67f857
共有 8 个文件被更改,包括 317 次插入和 43 次删除
  1. 4
      ml-agents/mlagents/trainers/policy/policy.py
  2. 1
      ml-agents/mlagents/trainers/policy/torch_policy.py
  3. 16
      ml-agents/mlagents/trainers/saver/tf_saver.py
  4. 30
      ml-agents/mlagents/trainers/saver/torch_saver.py
  5. 4
      ml-agents/mlagents/trainers/settings.py
  6. 12
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  7. 221
      ml-agents/mlagents/trainers/tf/model_serialization.py
  8. 72
      ml-agents/mlagents/trainers/torch/model_serialization.py

4
ml-agents/mlagents/trainers/policy/policy.py


from mlagents_envs.base_env import DecisionSteps
from mlagents_envs.exception import UnityException
from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings, NetworkSettings

self.vis_obs_size = sum(
1 for shape in behavior_spec.observation_shapes if len(shape) == 3
)
self.vis_obs_shape = [
shape for shape in behavior_spec.observation_shapes if len(shape) == 3
][0] if self.vis_obs_size > 0 else None
self.use_continuous_act = behavior_spec.is_action_continuous()
self.num_branches = self.behavior_spec.action_size
self.previous_action_dict: Dict[str, np.array] = {}

1
ml-agents/mlagents/trainers/policy/torch_policy.py


import torch
import os
from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.behavior_id_utils import get_global_agent_id

16
ml-agents/mlagents/trainers/saver/tf_saver.py


from mlagents_envs.logging_util import get_logger
from mlagents.tf_utils import tf
from mlagents.trainers.saver.saver import Saver
from mlagents.model_serialization import SerializationSettings, export_policy_model
from mlagents.trainers.tf.model_serialization import export_policy_model
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.tf_policy import TFPolicy

def register(self, module_dict):
pass
def save_checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
def save_checkpoint(self, checkpoint_path: str, brain_name: str) -> None:
:param settings: SerializationSettings for exporting the model.
:param brain_name: Brain name of brain to be trained
"""
print('save checkpoint_path:', checkpoint_path)
# Save the TF checkpoint and graph definition

self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
# also save the policy so we have optimized model files for each checkpoint
self.export(checkpoint_path, settings)
self.export(checkpoint_path, brain_name)
def export(self, output_filepath: str, settings: SerializationSettings) -> None:
def export(self, output_filepath: str, brain_name: str) -> None:
Saves the serialized model, given a path and SerializationSettings
Saves the serialized model, given a path and brain name.
This method will save the policy graph to the given filepath. The path
should be provided without an extension as multiple serialized model formats

:param settings: SerializationSettings for how to save the model.
:param brain_name: Brain name of brain to be trained.
export_policy_model(output_filepath, settings, self.graph, self.sess)
export_policy_model(output_filepath, brain_name, self.graph, self.sess)
def maybe_load(self):
# If there is an initialize path, load from that. Else, load from the set model path.

30
ml-agents/mlagents/trainers/saver/torch_saver.py


import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.saver.saver import Saver
from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.torch.model_serialization import ModelSerializer
logger = get_logger(__name__)

self.initialize_path = trainer_settings.init_path
self._keep_checkpoints = trainer_settings.keep_checkpoints
self.load = load
self.exporter = ModelSerializer(self.policy)
self.modules = {}

def save_checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
def save_checkpoint(self, checkpoint_path: str, brain_name: str) -> None:
:param settings: SerializationSettings for exporting the model.
:param brain_name: Brain name of brain to be trained
"""
print('save checkpoint_path:', checkpoint_path)
if not os.path.exists(self.model_path):

torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
self.export(checkpoint_path, brain_name)
def maybe_load(self):
# If there is an initialize path, load from that. Else, load from the set model path.

elif self.load:
self._load_model(self.model_path, reset_global_steps=reset_steps)
def export(self, output_filepath: str, settings: SerializationSettings) -> None:
def export(self, output_filepath: str, brain_name: str) -> None:
fake_vec_obs = [torch.zeros([1] + [self.policy.vec_obs_size])]
fake_vis_obs = [torch.zeros([1] + [84, 84, 3])]
fake_masks = torch.ones([1] + self.policy.actor_critic.act_size)
# print(fake_vec_obs[0].shape, fake_vis_obs[0].shape, fake_masks.shape)
# fake_memories = torch.zeros([1] + [self.m_size])
output_names = ["action", "action_probs", "is_continuous_control", \
"version_number", "memory_size", "action_output_shape"]
input_names = ["vector_observation", "action_mask"]
dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]}
torch.onnx.export(
self.policy.actor_critic,
(fake_vec_obs, fake_vis_obs, fake_masks),
f"{output_filepath}.onnx",
verbose=False,
opset_version=9,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
self.exporter.export_policy_model(output_filepath)
def _load_model(self, load_path: str, reset_global_steps: bool = False) -> None:
model_path = os.path.join(load_path, "checkpoint.pt")

4
ml-agents/mlagents/trainers/settings.py


env_name = ""
device = "cpu"
class SerializationSettings:
    # Global export toggles and options read by the model serializers.
    # When True, exported frozen graphs are also converted to Barracuda (.nn).
    convert_to_barracuda = True
    # When True, exported models are also converted to ONNX (.onnx).
    convert_to_onnx = True
    # ONNX opset version passed to both the TF and Torch exporters.
    onnx_opset = 9
@attr.s(auto_attribs=True)
class ExportableSettings:

12
ml-agents/mlagents/trainers/trainer/rl_trainer.py


import abc
import time
import attr
from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,

logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
# policy = list(self.policies.values())[0]
# model_path = policy.model_path
settings = SerializationSettings(self.saver.model_path, self.brain_name)
# policy.checkpoint(checkpoint_path, settings)
self.saver.save_checkpoint(checkpoint_path, settings)
self.saver.save_checkpoint(checkpoint_path, self.brain_name)
new_checkpoint = NNCheckpoint(
int(self.step),
f"{checkpoint_path}.nn",

logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
# policy = list(self.policies.values())[0]
settings = SerializationSettings(self.saver.model_path, self.brain_name)
# policy.save(policy.model_path, settings)
self.saver.export(self.saver.model_path, settings)
self.saver.export(self.saver.model_path, self.brain_name)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
@abc.abstractmethod

221
ml-agents/mlagents/trainers/tf/model_serialization.py


from distutils.util import strtobool
import os
from typing import Any, List, Set, NamedTuple
from distutils.version import LooseVersion
try:
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
from tf2onnx import optimizer
ONNX_EXPORT_ENABLED = True
except ImportError:
# Either onnx and tf2onnx not installed, or they're not compatible with the version of tensorflow
ONNX_EXPORT_ENABLED = False
pass
from mlagents.tf_utils import tf
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.tf import tensorflow_to_barracuda as tf2bc
from mlagents.trainers.settings import SerializationSettings
if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):
# ONNX is only tested on 1.12.0 and later
ONNX_EXPORT_ENABLED = False
logger = get_logger(__name__)
POSSIBLE_INPUT_NODES = frozenset(
[
"action_masks",
"epsilon",
"prev_action",
"recurrent_in",
"sequence_length",
"vector_observation",
]
)
POSSIBLE_OUTPUT_NODES = frozenset(
["action", "action_probs", "recurrent_out", "value_estimate"]
)
MODEL_CONSTANTS = frozenset(
[
"action_output_shape",
"is_continuous_control",
"memory_size",
"version_number",
"trainer_major_version",
"trainer_minor_version",
"trainer_patch_version",
]
)
VISUAL_OBSERVATION_PREFIX = "visual_observation_"
def export_policy_model(
    output_filepath: str,
    brain_name: str,
    graph: tf.Graph,
    sess: tf.Session,
) -> None:
    """
    Exports a TF graph for a Policy to .nn and/or .onnx format for Unity embedding.
    :param output_filepath: file path to output the model (without file suffix)
    :param brain_name: brain name, used to select graph nodes and label the ONNX model
    :param graph: Tensorflow Graph for the policy
    :param sess: Tensorflow session for the policy
    """
    frozen_graph_def = _make_frozen_graph(brain_name, graph, sess)
    # A directory named output_filepath holds the frozen graph; the .nn/.onnx
    # artifacts are written next to it with the same base name.
    if not os.path.exists(output_filepath):
        os.makedirs(output_filepath)
    # Save frozen graph
    frozen_graph_def_path = output_filepath + "/frozen_graph_def.pb"
    with gfile.GFile(frozen_graph_def_path, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())
    # Convert to barracuda
    if SerializationSettings.convert_to_barracuda:
        tf2bc.convert(frozen_graph_def_path, f"{output_filepath}.nn")
        logger.info(f"Exported {output_filepath}.nn")
    # Save to onnx too (if we were able to import it)
    if ONNX_EXPORT_ENABLED:
        if SerializationSettings.convert_to_onnx:
            try:
                onnx_graph = convert_frozen_to_onnx(brain_name, frozen_graph_def)
                onnx_output_path = f"{output_filepath}.onnx"
                with open(onnx_output_path, "wb") as f:
                    f.write(onnx_graph.SerializeToString())
                logger.info(f"Converting to {onnx_output_path}")
            except Exception:
                # Make conversion errors fatal depending on environment variables (only done during CI)
                if _enforce_onnx_conversion():
                    raise
                logger.exception(
                    "Exception trying to save ONNX graph. Please report this error on "
                    "https://github.com/Unity-Technologies/ml-agents/issues and "
                    "attach a copy of frozen_graph_def.pb"
                )
    else:
        # tf2onnx could not be imported at module load; fail hard only in CI.
        if _enforce_onnx_conversion():
            raise RuntimeError(
                "ONNX conversion enforced, but couldn't import dependencies."
            )
def _make_frozen_graph(
    brain_name: str, graph: tf.Graph, sess: tf.Session
) -> tf.GraphDef:
    """
    Freeze the policy graph for export: fold trained variables into constants,
    keeping only the nodes relevant for inference.
    :param brain_name: brain name, used when logging the exported node list.
    :param graph: graph to freeze.
    :param sess: session holding the variable values.
    :return: frozen GraphDef with variables converted to constants.
    """
    with graph.as_default():
        # Join then strip/split to normalize any stray spaces in node names.
        joined_targets = ",".join(_process_graph(brain_name, graph))
        target_list = joined_targets.replace(" ", "").split(",")
        return graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), target_list
        )
def convert_frozen_to_onnx(
    brain_name: str, frozen_graph_def: tf.GraphDef
) -> Any:
    """
    Convert a frozen TF GraphDef to an ONNX model proto via tf2onnx.
    :param brain_name: name stamped onto the resulting ONNX model.
    :param frozen_graph_def: frozen graph (variables already folded to constants).
    :return: ONNX ModelProto for the policy.
    """
    # This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
    inputs = _get_input_node_names(frozen_graph_def)
    outputs = _get_output_node_names(frozen_graph_def)
    logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")
    # Pre-optimize the TF graph (constant folding) before handing it to tf2onnx.
    frozen_graph_def = tf_optimize(
        inputs, outputs, frozen_graph_def, fold_constant=True
    )
    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(frozen_graph_def, name="")
        with tf.Session(graph=tf_graph):
            g = process_tf_graph(
                tf_graph,
                input_names=inputs,
                output_names=outputs,
                opset=SerializationSettings.onnx_opset,
            )
    onnx_graph = optimizer.optimize_graph(g)
    model_proto = onnx_graph.make_model(brain_name)
    return model_proto
def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
    """
    Get the list of input node names from the graph.
    Names are suffixed with ":0"
    :param frozen_graph_def: frozen GraphDef (any object with an iterable
        ``.node`` of named nodes).
    :return: sorted input node names with the ":0" port appended.
    """
    node_names = _get_frozen_graph_node_names(frozen_graph_def)
    input_names = node_names & POSSIBLE_INPUT_NODES
    # Check visual inputs sequentially, and exit as soon as we don't find one
    vis_index = 0
    while True:
        vis_node_name = f"{VISUAL_OBSERVATION_PREFIX}{vis_index}"
        if vis_node_name not in node_names:
            break
        input_names.add(vis_node_name)
        vis_index += 1
    # Sort so the exported model is deterministic: iterating a set yields an
    # arbitrary order, which previously made the ONNX input ordering vary
    # between runs.
    return [f"{n}:0" for n in sorted(input_names)]
def _get_output_node_names(frozen_graph_def: Any) -> List[str]:
    """
    Get the list of output node names from the graph.
    Also include constants, so that they will be readable by the
    onnx importer.
    Names are suffixed with ":0"
    :param frozen_graph_def: frozen GraphDef (any object with an iterable
        ``.node`` of named nodes).
    :return: sorted output/constant node names with the ":0" port appended.
    """
    node_names = _get_frozen_graph_node_names(frozen_graph_def)
    output_names = node_names & (POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS)
    # Sort so the exported model is deterministic: set intersection order is
    # arbitrary, which previously made the ONNX output ordering vary per run.
    return [f"{n}:0" for n in sorted(output_names)]
def _get_frozen_graph_node_names(frozen_graph_def: Any) -> Set[str]:
"""
Get all the node names from the graph.
"""
names = set()
for node in frozen_graph_def.node:
names.add(node.name)
return names
def _process_graph(brain_name: str, graph: tf.Graph) -> List[str]:
    """
    Gets the list of the output nodes present in the graph for inference
    :param brain_name: brain name, included in the log output.
    :param graph: graph whose nodes are filtered against the exportable set.
    :return: list of node names
    """
    exportable = POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS
    nodes = [node.name for node in graph.as_graph_def().node if node.name in exportable]
    logger.info("List of nodes to export for brain :" + brain_name)
    for node_name in nodes:
        logger.info("\t" + node_name)
    return nodes
def _enforce_onnx_conversion() -> bool:
env_var_name = "TEST_ENFORCE_ONNX_CONVERSION"
if env_var_name not in os.environ:
return False
val = os.environ[env_var_name]
try:
# This handles e.g. "false" converting reasonably to False
return strtobool(val)
except Exception:
return False

72
ml-agents/mlagents/trainers/torch/model_serialization.py


import os
from typing import Any, List, Set, NamedTuple
import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.settings import SerializationSettings
from IPython import embed
logger = get_logger(__name__)
POSSIBLE_INPUT_NODES = frozenset(
[
"action_masks",
"prev_action",
"recurrent_in",
"sequence_length",
"vector_observation",
]
)
POSSIBLE_OUTPUT_NODES = frozenset(
["action", "action_probs", "recurrent_out", "value_estimate"]
)
class ModelSerializer:
    """
    Exports a torch policy's actor_critic to ONNX for Unity embedding.
    A tuple of dummy inputs is built once at construction time and replayed
    through torch.onnx.export when export_policy_model is called.
    """

    def __init__(self, policy):
        # policy: torch policy; the attributes read below (vec_obs_size,
        # vis_obs_size, vis_obs_shape, m_size, sequence_length,
        # actor_critic.act_size) must already be set when this is constructed.
        self.policy = policy
        # Single-batch zero vector observation.
        dummy_vec_obs = [torch.zeros([1] + [self.policy.vec_obs_size])]
        # One dummy visual observation when the policy has visual inputs,
        # otherwise an empty list.
        dummy_vis_obs = [torch.zeros([1] + self.policy.vis_obs_shape)] \
            if self.policy.vis_obs_size > 0 else []
        # NOTE(review): assumes act_size is a list of per-branch action
        # sizes usable as tensor dims — confirm against the actor_critic.
        dummy_masks = [torch.ones([1] + self.policy.actor_critic.act_size)]
        dummy_memories = [torch.zeros([1] + [self.policy.m_size])]
        dummy_sequence_length = [torch.tensor([self.policy.sequence_length])]
        # Names and dynamic axes handed to torch.onnx.export; axis 0 (batch)
        # is marked dynamic on the listed tensors.
        self.input_names = ["vector_observation", "visual_observation", \
            "action_mask", "memories", "sequence_length"]
        self.output_names = ["action", "action_probs", "version_number", \
            "memory_size", "is_continuous_control", "action_output_shape"]
        self.dynamic_axes = {"vector_observation": [0], "visual_observation": [0], \
            "action_mask": [0], "memories": [0], "action": [0],"action_probs": [0]}
        self.dummy_input = (dummy_vec_obs, dummy_vis_obs, \
            dummy_masks, dummy_memories, dummy_sequence_length)

    def export_policy_model(self, output_filepath: str) -> None:
        """
        Exports a Torch model for a Policy to .onnx format for Unity embedding.
        :param output_filepath: file path to output the model (without file suffix)
        """
        # NOTE(review): this creates a directory named exactly output_filepath,
        # while the model is written to f"{output_filepath}.onnx" beside it —
        # presumably just to ensure the parent directory exists; confirm intent.
        if not os.path.exists(output_filepath):
            os.makedirs(output_filepath)
        onnx_output_path = f"{output_filepath}.onnx"
        logger.info(f"Converting to {onnx_output_path}")
        torch.onnx.export(
            self.policy.actor_critic,
            self.dummy_input,
            onnx_output_path,
            verbose=True,
            opset_version=SerializationSettings.onnx_opset,
            input_names=self.input_names,
            output_names=self.output_names,
            dynamic_axes=self.dynamic_axes,
        )
正在加载...
取消
保存