浏览代码

Add Saver Class to handle all save/load/checkpoint/export work (#4323)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
25dc8c3d
共有 27 个文件被更改,包括 650 次插入525 次删除
  1. 8
      ml-agents-envs/mlagents_envs/exception.py
  2. 2
      ml-agents/mlagents/trainers/ghost/trainer.py
  3. 15
      ml-agents/mlagents/trainers/policy/policy.py
  4. 112
      ml-agents/mlagents/trainers/policy/tf_policy.py
  5. 2
      ml-agents/mlagents/trainers/ppo/optimizer.py
  6. 10
      ml-agents/mlagents/trainers/ppo/trainer.py
  7. 2
      ml-agents/mlagents/trainers/sac/optimizer.py
  8. 10
      ml-agents/mlagents/trainers/sac/trainer.py
  9. 6
      ml-agents/mlagents/trainers/settings.py
  10. 27
      ml-agents/mlagents/trainers/tests/test_barracuda_converter.py
  11. 10
      ml-agents/mlagents/trainers/tests/test_bcmodule.py
  12. 62
      ml-agents/mlagents/trainers/tests/test_nn_policy.py
  13. 8
      ml-agents/mlagents/trainers/tests/test_ppo.py
  14. 1
      ml-agents/mlagents/trainers/tests/test_reward_signals.py
  15. 21
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py
  16. 6
      ml-agents/mlagents/trainers/tests/test_sac.py
  17. 8
      ml-agents/mlagents/trainers/tests/test_simple_rl.py
  18. 20
      ml-agents/mlagents/trainers/tests/test_tf_policy.py
  19. 26
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  20. 2
      ml-agents/mlagents/trainers/trainer/trainer.py
  21. 113
      ml-agents/mlagents/trainers/tests/test_saver.py
  22. 221
      ml-agents/mlagents/trainers/tf/model_serialization.py
  23. 0
      ml-agents/mlagents/trainers/saver/__init__.py
  24. 66
      ml-agents/mlagents/trainers/saver/saver.py
  25. 170
      ml-agents/mlagents/trainers/saver/tf_saver.py
  26. 247
      ml-agents/mlagents/model_serialization.py

8
ml-agents-envs/mlagents_envs/exception.py


def __init__(self, worker_id):
message = self.MESSAGE_TEMPLATE.format(str(worker_id))
super().__init__(message)
class UnityPolicyException(UnityException):
"""
Related to errors with the Trainer.
"""
pass

2
ml-agents/mlagents/trainers/ghost/trainer.py


"""
policy = self.trainer.create_policy(parsed_behavior_id, behavior_spec)
policy.create_tf_graph()
policy.initialize_or_load()
self.trainer.saver.initialize_or_load(policy)
policy.init_load_weights()
team_id = parsed_behavior_id.team_id
self.controller.subscribe_team_id(team_id, self)

15
ml-agents/mlagents/trainers/policy/policy.py


from mlagents_envs.base_env import DecisionSteps
from mlagents_envs.exception import UnityException
from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings, NetworkSettings

seed: int,
behavior_spec: BehaviorSpec,
trainer_settings: TrainerSettings,
model_path: str,
load: bool = False,
tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,

self.vis_obs_size = sum(
1 for shape in behavior_spec.observation_shapes if len(shape) == 3
)
self.model_path = model_path
self.initialize_path = self.trainer_settings.init_path
self._keep_checkpoints = self.trainer_settings.keep_checkpoints
self.use_continuous_act = behavior_spec.is_action_continuous()
self.num_branches = self.behavior_spec.action_size
self.previous_action_dict: Dict[str, np.array] = {}

self.load = load
self.h_size = self.network_settings.hidden_units
num_layers = self.network_settings.num_layers
if num_layers < 1:

@abstractmethod
def get_current_step(self):
pass
@abstractmethod
def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
pass
@abstractmethod
def save(self, output_filepath: str, settings: SerializationSettings) -> None:
pass
@abstractmethod

112
ml-agents/mlagents/trainers/policy/tf_policy.py


from mlagents_envs.timers import timed
from mlagents.model_serialization import SerializationSettings, export_policy_model
from mlagents.tf_utils import tf
from mlagents import tf_utils
from mlagents_envs.exception import UnityException

seed: int,
behavior_spec: BehaviorSpec,
trainer_settings: TrainerSettings,
model_path: str,
load: bool = False,
tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,

:param seed: Random seed to use for TensorFlow.
:param brain: The corresponding Brain for this policy.
:param trainer_settings: The trainer parameters.
:param model_path: Where to load/save the model.
:param load: If True, load model from model_path. Otherwise, create new model.
model_path,
load,
tanh_squash,
reparameterize,
condition_sigma_on_obs,

self.sess = tf.Session(
config=tf_utils.generate_session_config(), graph=self.graph
)
self.saver: Optional[tf.Operation] = None
self._initialize_tensorflow_references()
self.grads = None
self.update_batch: Optional[tf.Operation] = None

# We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
# it will re-load the full graph
self._initialize_graph()
self.initialize()
def _create_encoder(
self,

ver = LooseVersion(version_string)
return tuple(map(int, ver.version[0:3]))
def _check_model_version(self, version: str) -> None:
"""
Checks whether the model being loaded was created with the same version of
ML-Agents, and throw a warning if not so.
"""
if self.version_tensors is not None:
loaded_ver = tuple(
num.eval(session=self.sess) for num in self.version_tensors
)
if loaded_ver != TFPolicy._convert_version_string(version):
logger.warning(
f"The model checkpoint you are loading from was saved with ML-Agents version "
f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents"
f"version is {version}. Model may not behave properly."
)
def _initialize_graph(self):
def initialize(self):
self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
with self.graph.as_default():
self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
logger.info(f"Loading model from {model_path}.")
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt is None:
raise UnityPolicyException(
"The model {} could not be loaded. Make "
"sure you specified the right "
"--run-id and that the previous run you are loading from had the same "
"behavior names.".format(model_path)
)
try:
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
except tf.errors.NotFoundError:
raise UnityPolicyException(
"The model {} was found but could not be loaded. Make "
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
"and is using the same trainer configuration as the current run.".format(
model_path
)
)
self._check_model_version(__version__)
if reset_global_steps:
self._set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path
)
)
else:
logger.info(f"Resuming training from step {self.get_current_step()}.")
def initialize_or_load(self):
# If there is an initialize path, load from that. Else, load from the set model path.
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
# e.g., resume from an initialize path.
reset_steps = not self.load
if self.initialize_path is not None:
self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
elif self.load:
self._load_graph(self.model_path, reset_global_steps=reset_steps)
else:
self._initialize_graph()
# broadcast initial weights from worker-0
TFPolicy.broadcast_global_variables(0)
def get_weights(self):
with self.graph.as_default():
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

step = self.sess.run(self.global_step)
return step
def _set_step(self, step: int) -> int:
def set_step(self, step: int) -> int:
"""
Sets current model step to step without creating additional ops.
:param step: Step to set the current model step to.

:return:list of update var names
"""
return list(self.update_dict.keys())
def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param settings: SerializationSettings for exporting the model.
"""
# Save the TF checkpoint and graph definition
with self.graph.as_default():
if self.saver:
self.saver.save(self.sess, f"{checkpoint_path}.ckpt")
tf.train.write_graph(
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
# also save the policy so we have optimized model files for each checkpoint
self.save(checkpoint_path, settings)
def save(self, output_filepath: str, settings: SerializationSettings) -> None:
"""
Saves the serialized model, given a path and SerializationSettings
This method will save the policy graph to the given filepath. The path
should be provided without an extension as multiple serialized model formats
may be generated as a result.
:param output_filepath: path (without suffix) for the model file(s)
:param settings: SerializationSettings for how to save the model.
"""
# save model if there is only one worker or
# only on worker-0 if there are multiple workers
if self.rank is not None and self.rank != 0:
return
export_policy_model(output_filepath, settings, self.graph, self.sess)
def update_normalization(self, vector_obs: np.ndarray) -> None:
"""

2
ml-agents/mlagents/trainers/ppo/optimizer.py


}
)
self.policy.initialize_or_load()
def _create_cc_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:

10
ml-agents/mlagents/trainers/ppo/trainer.py


:param artifact_path: The directory within which to store artifacts from this trainer.
"""
super().__init__(
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
brain_name, trainer_settings, training, load, artifact_path, reward_buff_cap
self.load = load
self.seed = seed
self.policy: Policy = None # type: ignore

self.seed,
behavior_spec,
self.trainer_settings,
model_path=self.artifact_path,
load=self.load,
condition_sigma_on_obs=False, # Faster training for PPO
create_tf_graph=False, # We will create the TF graph in the Optimizer
)

self.optimizer = self.create_ppo_optimizer()
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
self.saver.register(self.policy)
self.saver.register(self.optimizer)
self.saver.initialize_or_load()
# Needed to resume loads properly
self.step = policy.get_current_step()

2
ml-agents/mlagents/trainers/sac/optimizer.py


[self.policy.update_normalization_op, target_update_norm]
)
self.policy.initialize_or_load()
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",

10
ml-agents/mlagents/trainers/sac/trainer.py


:param artifact_path: The directory within which to store artifacts from this trainer.
"""
super().__init__(
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
brain_name, trainer_settings, training, load, artifact_path, reward_buff_cap
self.load = load
self.seed = seed
self.policy: Policy = None # type: ignore
self.optimizer: SACOptimizer = None # type: ignore

self.seed,
behavior_spec,
self.trainer_settings,
self.artifact_path,
self.load,
tanh_squash=True,
reparameterize=True,
create_tf_graph=False,

self.optimizer = self.create_sac_optimizer()
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
self.saver.register(self.policy)
self.saver.register(self.optimizer)
self.saver.initialize_or_load()
# Needed to resume loads properly
self.step = policy.get_current_step()
# Assume steps were updated at the correct ratio before

6
ml-agents/mlagents/trainers/settings.py


return {key: cattr.unstructure(val) for key, val in d.items()}
class SerializationSettings:
convert_to_barracuda = True
convert_to_onnx = True
onnx_opset = 9
@attr.s(auto_attribs=True)
class ExportableSettings:
def as_dict(self):

27
ml-agents/mlagents/trainers/tests/test_barracuda_converter.py


import os
import tempfile
import pytest
from mlagents.trainers.tests.test_nn_policy import create_policy_mock
from mlagents.trainers.settings import TrainerSettings
from mlagents.tf_utils import tf
from mlagents.model_serialization import SerializationSettings
def test_barracuda_converter():

# cleanup
os.remove(tmpfile)
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_policy_conversion(tmpdir, rnn, visual, discrete):
tf.reset_default_graph()
dummy_config = TrainerSettings()
policy = create_policy_mock(
dummy_config,
use_rnn=rnn,
model_path=os.path.join(tmpdir, "test"),
use_discrete=discrete,
use_visual=visual,
)
settings = SerializationSettings(policy.model_path, "MockBrain")
checkpoint_path = f"{tmpdir}/MockBrain-1"
policy.checkpoint(checkpoint_path, settings)
# These checks taken from test_barracuda_converter
assert os.path.isfile(checkpoint_path + ".nn")
assert os.path.getsize(checkpoint_path + ".nn") > 100

10
ml-agents/mlagents/trainers/tests/test_bcmodule.py


NetworkSettings.MemorySettings() if use_rnn else None
)
policy = TFPolicy(
0,
mock_behavior_specs,
trainer_config,
"test",
False,
tanhresample,
tanhresample,
0, mock_behavior_specs, trainer_config, tanhresample, tanhresample
)
with policy.graph.as_default():
bc_module = BCModule(

default_num_epoch=3,
settings=bc_settings,
)
policy.initialize_or_load() # Normally the optimizer calls this after the BCModule is created
policy.initialize() # Normally the optimizer calls this after the BCModule is created
return bc_module

62
ml-agents/mlagents/trainers/tests/test_nn_policy.py


import pytest
import os
import unittest
import tempfile
from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.tf.models import ModelUtils, Tensor3DShape
from mlagents.trainers.exception import UnityTrainerException

from mlagents.trainers import __version__
VECTOR_ACTION_SPACE = 2

use_rnn: bool = False,
use_discrete: bool = True,
use_visual: bool = False,
model_path: str = "",
load: bool = False,
seed: int = 0,
) -> TFPolicy:
mock_spec = mb.setup_test_behavior_specs(

trainer_settings.network_settings.memory = (
NetworkSettings.MemorySettings() if use_rnn else None
)
policy = TFPolicy(
seed, mock_spec, trainer_settings, model_path=model_path, load=load
)
policy = TFPolicy(seed, mock_spec, trainer_settings)
def test_load_save(tmp_path):
path1 = os.path.join(tmp_path, "runid1")
path2 = os.path.join(tmp_path, "runid2")
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params, model_path=path1)
policy.initialize_or_load()
policy._set_step(2000)
mock_brain_name = "MockBrain"
checkpoint_path = f"{policy.model_path}/{mock_brain_name}-2000"
serialization_settings = SerializationSettings(policy.model_path, mock_brain_name)
policy.checkpoint(checkpoint_path, serialization_settings)
assert len(os.listdir(tmp_path)) > 0
# Try load from this path
policy2 = create_policy_mock(trainer_params, model_path=path1, load=True, seed=1)
policy2.initialize_or_load()
_compare_two_policies(policy, policy2)
assert policy2.get_current_step() == 2000
# Try initialize from path 1
trainer_params.output_path = path2
trainer_params.init_path = path1
policy3 = create_policy_mock(trainer_params, model_path=path1, load=False, seed=2)
policy3.initialize_or_load()
_compare_two_policies(policy2, policy3)
# Assert that the steps are 0.
assert policy3.get_current_step() == 0
class ModelVersionTest(unittest.TestCase):
def test_version_compare(self):
# Test write_stats
with self.assertLogs("mlagents.trainers", level="WARNING") as cm:
path1 = tempfile.mkdtemp()
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params, model_path=path1)
policy.initialize_or_load()
policy._check_model_version(
"0.0.0"
) # This is not the right version for sure
# Assert that 1 warning has been thrown with incorrect version
assert len(cm.output) == 1
policy._check_model_version(__version__) # This should be the right version
# Assert that no additional warnings have been thrown wth correct ver
assert len(cm.output) == 1
def _compare_two_policies(policy1: TFPolicy, policy2: TFPolicy) -> None:

8
ml-agents/mlagents/trainers/tests/test_ppo.py


import attr
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.ppo.trainer import PPOTrainer, discount_rewards
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.policy.tf_policy import TFPolicy

0, mock_specs, trainer_settings, "test", False, create_tf_graph=False
)
optimizer = PPOOptimizer(policy, trainer_settings)
policy.initialize()
return optimizer

)
@mock.patch.object(RLTrainer, "create_saver")
def test_trainer_increment_step(ppo_optimizer):
def test_trainer_increment_step(ppo_optimizer, mock_create_saver):
trainer_params = PPO_CONFIG
mock_optimizer = mock.Mock()
mock_optimizer.reward_signals = {}

assert trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").num > 0
@mock.patch.object(RLTrainer, "create_saver")
def test_add_get_policy(ppo_optimizer, dummy_config):
def test_add_get_policy(ppo_optimizer, mock_create_saver, dummy_config):
mock_optimizer = mock.Mock()
mock_optimizer.reward_signals = {}
ppo_optimizer.return_value = mock_optimizer

1
ml-agents/mlagents/trainers/tests/test_reward_signals.py


optimizer = SACOptimizer(policy, trainer_settings)
else:
optimizer = PPOOptimizer(policy, trainer_settings)
optimizer.policy.initialize()
return optimizer

21
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


import os
from unittest import mock
import pytest
import mlagents.trainers.tests.mock_brain as mb

return self.update_policy
def add_policy(self, mock_behavior_id, mock_policy):
def checkpoint_path(brain_name, step):
return os.path.join(self.saver.model_path, f"{brain_name}-{step}")
mock_saver = mock.Mock()
mock_saver.model_path = self.artifact_path
mock_saver.save_checkpoint.side_effect = checkpoint_path
self.saver = mock_saver
def create_policy(self):
return mock.Mock()

"test_trainer",
TrainerSettings(max_steps=100, checkpoint_interval=10, summary_freq=20),
True,
False,
"mock_model_path",
0,
)
trainer.set_is_policy_updating(True)

def test_advance(mocked_clear_update_buffer, mocked_save_model):
trainer = create_rl_trainer()
mock_policy = mock.Mock()
mock_policy.model_path = "mock_model_path"
trainer.add_policy("TestBrain", mock_policy)
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")

def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
trainer = create_rl_trainer()
mock_policy = mock.Mock()
mock_policy.model_path = "mock_model_path"
trainer.add_policy("TestBrain", mock_policy)
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")

checkpoint_range = range(
checkpoint_interval, num_trajectories * time_horizon, checkpoint_interval
)
calls = [
mock.call(f"{mock_policy.model_path}/{trainer.brain_name}-{step}", mock.ANY)
for step in checkpoint_range
]
mock_policy.checkpoint.assert_has_calls(calls, any_order=True)
calls = [mock.call(trainer.brain_name, step) for step in checkpoint_range]
trainer.saver.save_checkpoint.assert_has_calls(calls, any_order=True)
add_checkpoint_calls = [
mock.call(

f"{mock_policy.model_path}/{trainer.brain_name}-{step}.nn",
f"{trainer.saver.model_path}/{trainer.brain_name}-{step}.nn",
None,
mock.ANY,
),

6
ml-agents/mlagents/trainers/tests/test_sac.py


from mlagents.tf_utils import tf
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.sac.optimizer import SACOptimizer
from mlagents.trainers.policy.tf_policy import TFPolicy

0, mock_brain, trainer_settings, "test", False, create_tf_graph=False
)
optimizer = SACOptimizer(policy, trainer_settings)
policy.initialize()
return optimizer

assert trainer2.update_buffer.num_experiences == buffer_len
@mock.patch.object(RLTrainer, "create_saver")
def test_add_get_policy(sac_optimizer, dummy_config):
def test_add_get_policy(sac_optimizer, mock_create_saver, dummy_config):
mock_optimizer = mock.Mock()
mock_optimizer.reward_signals = {}
sac_optimizer.return_value = mock_optimizer

policy = trainer.create_policy(behavior_id, specs)
policy.get_current_step = lambda: 200
trainer.add_policy(behavior_id, policy)
trainer.saver.initialize_or_load(policy)
trainer.optimizer.update = mock.Mock()
trainer.optimizer.update_reward_signals = mock.Mock()
trainer.optimizer.update_reward_signals.return_value = {}

8
ml-agents/mlagents/trainers/tests/test_simple_rl.py


# The reward processor is passed as an argument to _check_environment_trains.
# It is applied to the list pf all final rewards for each brain individually.
# It is applied to the list of all final rewards for each brain individually.
# Custom reward processors shuld be built within the test function and passed to _check_environment_trains
# Custom reward processors should be built within the test function and passed to _check_environment_trains
# Default is average over the last 5 final rewards
def default_reward_processor(rewards, last_n_rewards=5):
rewards_to_use = rewards[-last_n_rewards:]

@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_sac(use_discrete):
step_size = 0.2 if use_discrete else 1.0
step_size = 0.5 if use_discrete else 0.2
env = MemoryEnvironment(
[BRAIN_NAME], use_discrete=use_discrete, step_size=step_size
)

swap_steps=5000,
team_change=2000,
)
config = attr.evolve(PPO_CONFIG, self_play=self_play_settings, max_steps=2000)
config = attr.evolve(PPO_CONFIG, self_play=self_play_settings, max_steps=3000)
_check_environment_trains(
env, {BRAIN_NAME: config, brain_name_opp: config}, success_threshold=None
)

20
ml-agents/mlagents/trainers/tests/test_tf_policy.py


from mlagents.model_serialization import SerializationSettings
from unittest import mock
from mlagents.trainers.settings import TrainerSettings
import numpy as np

# Test dev versions
result = TFPolicy._convert_version_string("200.300.100.dev0")
assert result == (200, 300, 100)
@mock.patch("mlagents.trainers.policy.tf_policy.export_policy_model")
@mock.patch("time.time", mock.MagicMock(return_value=12345))
def test_checkpoint_writes_tf_and_nn_checkpoints(export_policy_model_mock):
mock_brain = basic_mock_brain()
test_seed = 4 # moving up in the world
policy = FakePolicy(test_seed, mock_brain, TrainerSettings(), "output")
n_steps = 5
policy.get_current_step = MagicMock(return_value=n_steps)
policy.saver = MagicMock()
serialization_settings = SerializationSettings("output", mock_brain.brain_name)
checkpoint_path = f"output/{mock_brain.brain_name}-{n_steps}"
policy.checkpoint(checkpoint_path, serialization_settings)
policy.saver.save.assert_called_once_with(policy.sess, f"{checkpoint_path}.ckpt")
export_policy_model_mock.assert_called_once_with(
checkpoint_path, serialization_settings, policy.graph, policy.sess
)

26
ml-agents/mlagents/trainers/trainer/rl_trainer.py


# # Unity ML-Agents Toolkit
import os
from mlagents.model_serialization import SerializationSettings, copy_model_files
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,

from mlagents_envs.timers import hierarchical_timer
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.saver.tf_saver import TFSaver
RewardSignalResults = Dict[str, RewardSignalResult]

)
self._next_save_step = 0
self._next_summary_step = 0
self.saver = self.create_saver(
self.trainer_settings, self.artifact_path, self.load
)
def end_episode(self) -> None:
"""

for agent_id in rewards:
rewards[agent_id] = 0
@staticmethod
def create_saver(
trainer_settings: TrainerSettings, model_path: str, load: bool
) -> BaseSaver:
saver = TFSaver(trainer_settings, model_path, load)
return saver
def _update_end_episode_stats(self, agent_id: str, optimizer: Optimizer) -> None:
for name, rewards in self.collected_rewards.items():
if name == "environment":

logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
model_path = policy.model_path
settings = SerializationSettings(model_path, self.brain_name)
checkpoint_path = os.path.join(model_path, f"{self.brain_name}-{self.step}")
policy.checkpoint(checkpoint_path, settings)
checkpoint_path = self.saver.save_checkpoint(self.brain_name, self.step)
new_checkpoint = NNCheckpoint(
int(self.step),
f"{checkpoint_path}.nn",

logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn")
self.saver.copy_final_model(model_checkpoint.file_path)
model_checkpoint, file_path=f"{policy.model_path}.nn"
model_checkpoint, file_path=f"{self.saver.model_path}.nn"
)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)

2
ml-agents/mlagents/trainers/trainer/trainer.py


brain_name: str,
trainer_settings: TrainerSettings,
training: bool,
load: bool,
artifact_path: str,
reward_buff_cap: int = 1,
):

self._threaded = trainer_settings.threaded
self._stats_reporter = StatsReporter(brain_name)
self.is_training = training
self.load = load
self._reward_buffer: Deque[float] = deque(maxlen=reward_buff_cap)
self.policy_queues: List[AgentManagerQueue[Policy]] = []
self.trajectory_queues: List[AgentManagerQueue[Trajectory]] = []

113
ml-agents/mlagents/trainers/tests/test_saver.py


import pytest
from unittest import mock
import os
import unittest
import tempfile
import numpy as np
from mlagents.tf_utils import tf
from mlagents.trainers.saver.tf_saver import TFSaver
from mlagents.trainers import __version__
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.test_nn_policy import create_policy_mock
from mlagents.trainers.ppo.optimizer import PPOOptimizer
def test_register(tmp_path):
trainer_params = TrainerSettings()
saver = TFSaver(trainer_params, tmp_path)
opt = mock.Mock(spec=PPOOptimizer)
saver.register(opt)
assert saver.policy is None
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params)
saver.register(policy)
assert saver.policy is not None
class ModelVersionTest(unittest.TestCase):
def test_version_compare(self):
# Test write_stats
with self.assertLogs("mlagents.trainers", level="WARNING") as cm:
trainer_params = TrainerSettings()
mock_path = tempfile.mkdtemp()
policy = create_policy_mock(trainer_params)
saver = TFSaver(trainer_params, mock_path)
saver.register(policy)
saver._check_model_version(
"0.0.0"
) # This is not the right version for sure
# Assert that 1 warning has been thrown with incorrect version
assert len(cm.output) == 1
saver._check_model_version(__version__) # This should be the right version
# Assert that no additional warnings have been thrown wth correct ver
assert len(cm.output) == 1
def test_load_save(tmp_path):
path1 = os.path.join(tmp_path, "runid1")
path2 = os.path.join(tmp_path, "runid2")
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params)
saver = TFSaver(trainer_params, path1)
saver.register(policy)
saver.initialize_or_load(policy)
policy.set_step(2000)
mock_brain_name = "MockBrain"
saver.save_checkpoint(mock_brain_name, 2000)
assert len(os.listdir(tmp_path)) > 0
# Try load from this path
saver = TFSaver(trainer_params, path1, load=True)
policy2 = create_policy_mock(trainer_params)
saver.register(policy2)
saver.initialize_or_load(policy2)
_compare_two_policies(policy, policy2)
assert policy2.get_current_step() == 2000
# Try initialize from path 1
trainer_params.init_path = path1
saver = TFSaver(trainer_params, path2)
policy3 = create_policy_mock(trainer_params)
saver.register(policy3)
saver.initialize_or_load(policy3)
_compare_two_policies(policy2, policy3)
# Assert that the steps are 0.
assert policy3.get_current_step() == 0
def _compare_two_policies(policy1: TFPolicy, policy2: TFPolicy) -> None:
"""
Make sure two policies have the same output for the same input.
"""
decision_step, _ = mb.create_steps_from_behavior_spec(
policy1.behavior_spec, num_agents=1
)
run_out1 = policy1.evaluate(decision_step, list(decision_step.agent_id))
run_out2 = policy2.evaluate(decision_step, list(decision_step.agent_id))
np.testing.assert_array_equal(run_out2["log_probs"], run_out1["log_probs"])
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_checkpoint_conversion(tmpdir, rnn, visual, discrete):
tf.reset_default_graph()
dummy_config = TrainerSettings()
model_path = os.path.join(tmpdir, "Mock_Brain")
policy = create_policy_mock(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
trainer_params = TrainerSettings()
saver = TFSaver(trainer_params, model_path)
saver.register(policy)
saver.save_checkpoint("Mock_Brain", 100)
assert os.path.isfile(model_path + "/Mock_Brain-100.nn")

221
ml-agents/mlagents/trainers/tf/model_serialization.py


from distutils.util import strtobool
import os
from typing import Any, List, Set
from distutils.version import LooseVersion
try:
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
from tf2onnx import optimizer
ONNX_EXPORT_ENABLED = True
except ImportError:
# Either onnx and tf2onnx not installed, or they're not compatible with the version of tensorflow
ONNX_EXPORT_ENABLED = False
pass
from mlagents.tf_utils import tf
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.settings import SerializationSettings
from mlagents.trainers.tf import tensorflow_to_barracuda as tf2bc
if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):
# ONNX is only tested on 1.12.0 and later
ONNX_EXPORT_ENABLED = False
logger = get_logger(__name__)
POSSIBLE_INPUT_NODES = frozenset(
[
"action_masks",
"epsilon",
"prev_action",
"recurrent_in",
"sequence_length",
"vector_observation",
]
)
POSSIBLE_OUTPUT_NODES = frozenset(
["action", "action_probs", "recurrent_out", "value_estimate"]
)
MODEL_CONSTANTS = frozenset(
[
"action_output_shape",
"is_continuous_control",
"memory_size",
"version_number",
"trainer_major_version",
"trainer_minor_version",
"trainer_patch_version",
]
)
VISUAL_OBSERVATION_PREFIX = "visual_observation_"
def export_policy_model(
model_path: str,
output_filepath: str,
brain_name: str,
graph: tf.Graph,
sess: tf.Session,
) -> None:
"""
Exports a TF graph for a Policy to .nn and/or .onnx format for Unity embedding.
:param output_filepath: file path to output the model (without file suffix)
:param brain_name: brain name of the trained model
:param graph: Tensorflow Graph for the policy
:param sess: Tensorflow session for the policy
"""
frozen_graph_def = _make_frozen_graph(brain_name, graph, sess)
if not os.path.exists(output_filepath):
os.makedirs(output_filepath)
# Save frozen graph
frozen_graph_def_path = model_path + "/frozen_graph_def.pb"
with gfile.GFile(frozen_graph_def_path, "wb") as f:
f.write(frozen_graph_def.SerializeToString())
# Convert to barracuda
if SerializationSettings.convert_to_barracuda:
tf2bc.convert(frozen_graph_def_path, f"{output_filepath}.nn")
logger.info(f"Exported {output_filepath}.nn")
# Save to onnx too (if we were able to import it)
if ONNX_EXPORT_ENABLED:
if SerializationSettings.convert_to_onnx:
try:
onnx_graph = convert_frozen_to_onnx(brain_name, frozen_graph_def)
onnx_output_path = f"{output_filepath}.onnx"
with open(onnx_output_path, "wb") as f:
f.write(onnx_graph.SerializeToString())
logger.info(f"Converting to {onnx_output_path}")
except Exception:
# Make conversion errors fatal depending on environment variables (only done during CI)
if _enforce_onnx_conversion():
raise
logger.exception(
"Exception trying to save ONNX graph. Please report this error on "
"https://github.com/Unity-Technologies/ml-agents/issues and "
"attach a copy of frozen_graph_def.pb"
)
else:
if _enforce_onnx_conversion():
raise RuntimeError(
"ONNX conversion enforced, but couldn't import dependencies."
)
def _make_frozen_graph(
brain_name: str, graph: tf.Graph, sess: tf.Session
) -> tf.GraphDef:
with graph.as_default():
target_nodes = ",".join(_process_graph(brain_name, graph))
graph_def = graph.as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess, graph_def, target_nodes.replace(" ", "").split(",")
)
return output_graph_def
def convert_frozen_to_onnx(brain_name: str, frozen_graph_def: tf.GraphDef) -> Any:
# This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
inputs = _get_input_node_names(frozen_graph_def)
outputs = _get_output_node_names(frozen_graph_def)
logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")
frozen_graph_def = tf_optimize(
inputs, outputs, frozen_graph_def, fold_constant=True
)
with tf.Graph().as_default() as tf_graph:
tf.import_graph_def(frozen_graph_def, name="")
with tf.Session(graph=tf_graph):
g = process_tf_graph(
tf_graph,
input_names=inputs,
output_names=outputs,
opset=SerializationSettings.onnx_opset,
)
onnx_graph = optimizer.optimize_graph(g)
model_proto = onnx_graph.make_model(brain_name)
return model_proto
def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
"""
Get the list of input node names from the graph.
Names are suffixed with ":0"
"""
node_names = _get_frozen_graph_node_names(frozen_graph_def)
input_names = node_names & POSSIBLE_INPUT_NODES
# Check visual inputs sequentially, and exit as soon as we don't find one
vis_index = 0
while True:
vis_node_name = f"{VISUAL_OBSERVATION_PREFIX}{vis_index}"
if vis_node_name in node_names:
input_names.add(vis_node_name)
else:
break
vis_index += 1
# Append the port
return [f"{n}:0" for n in input_names]
def _get_output_node_names(frozen_graph_def: Any) -> List[str]:
"""
Get the list of output node names from the graph.
Also include constants, so that they will be readable by the
onnx importer.
Names are suffixed with ":0"
"""
node_names = _get_frozen_graph_node_names(frozen_graph_def)
output_names = node_names & (POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS)
# Append the port
return [f"{n}:0" for n in output_names]
def _get_frozen_graph_node_names(frozen_graph_def: Any) -> Set[str]:
"""
Get all the node names from the graph.
"""
names = set()
for node in frozen_graph_def.node:
names.add(node.name)
return names
def _process_graph(brain_name: str, graph: tf.Graph) -> List[str]:
"""
Gets the list of the output nodes present in the graph for inference
:return: list of node names
"""
all_nodes = [x.name for x in graph.as_graph_def().node]
nodes = [x for x in all_nodes if x in POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS]
logger.info("List of nodes to export for brain :" + brain_name)
for n in nodes:
logger.info("\t" + n)
return nodes
def _enforce_onnx_conversion() -> bool:
env_var_name = "TEST_ENFORCE_ONNX_CONVERSION"
if env_var_name not in os.environ:
return False
val = os.environ[env_var_name]
try:
# This handles e.g. "false" converting reasonably to False
return strtobool(val)
except Exception:
return False

0
ml-agents/mlagents/trainers/saver/__init__.py

66
ml-agents/mlagents/trainers/saver/saver.py


# # Unity ML-Agents Toolkit
import abc
from typing import Any
class BaseSaver(abc.ABC):
"""This class is the base class for the Saver"""
def __init__(self):
pass
@abc.abstractmethod
def register(self, module: Any) -> None:
"""
Register the modules to the Saver.
The Saver will store the module and include it in the saved files
when saving checkpoint/exporting graph.
:param module: the module to be registered
"""
pass
def _register_policy(self, policy):
"""
Helper function for registering policy to the Saver.
:param policy: the policy to be registered
"""
pass
def _register_optimizer(self, optimizer):
"""
Helper function for registering optimizer to the Saver.
:param optimizer: the optimizer to be registered
"""
pass
@abc.abstractmethod
def save_checkpoint(self, brain_name: str, step: int) -> str:
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param brain_name: Brain name of brain to be trained
"""
pass
@abc.abstractmethod
def export(self, output_filepath: str, brain_name: str) -> None:
"""
Saves the serialized model, given a path and brain name.
This method will save the policy graph to the given filepath. The path
should be provided without an extension as multiple serialized model formats
may be generated as a result.
:param output_filepath: path (without suffix) for the model file(s)
:param brain_name: Brain name of brain to be trained.
"""
pass
@abc.abstractmethod
def initialize_or_load(self, policy):
"""
Initialize/Load registered modules by default.
If given input argument policy, do with the input policy instead.
This argument is mainly for the initialization of the ghost trainer's fixed policy.
:param policy (optional): if given, perform the initializing/loading on this input policy.
Otherwise, do with the registered policy
"""
pass

170
ml-agents/mlagents/trainers/saver/tf_saver.py


import os
import shutil
from typing import Optional, Union, cast
from mlagents_envs.exception import UnityPolicyException
from mlagents_envs.logging_util import get_logger
from mlagents.tf_utils import tf
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.tf.model_serialization import export_policy_model
from mlagents.trainers.settings import TrainerSettings, SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers import __version__
logger = get_logger(__name__)
class TFSaver(BaseSaver):
"""
Saver class for TensorFlow
"""
def __init__(
self, trainer_settings: TrainerSettings, model_path: str, load: bool = False
):
super().__init__()
self.model_path = model_path
self.initialize_path = trainer_settings.init_path
self._keep_checkpoints = trainer_settings.keep_checkpoints
self.load = load
# Currently only support saving one policy. This is the one to be saved.
self.policy: Optional[TFPolicy] = None
self.graph = None
self.sess = None
self.tf_saver = None
def register(self, module: Union[TFPolicy, TFOptimizer]) -> None:
if isinstance(module, TFPolicy):
self._register_policy(module)
elif isinstance(module, TFOptimizer):
self._register_optimizer(module)
else:
raise UnityPolicyException(
"Registering Object of unsupported type {} to Saver ".format(
type(module)
)
)
def _register_policy(self, policy: TFPolicy) -> None:
if self.policy is None:
self.policy = policy
self.graph = self.policy.graph
self.sess = self.policy.sess
with self.policy.graph.as_default():
self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
def save_checkpoint(self, brain_name: str, step: int) -> str:
checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
# Save the TF checkpoint and graph definition
if self.graph:
with self.graph.as_default():
if self.tf_saver:
self.tf_saver.save(self.sess, f"{checkpoint_path}.ckpt")
tf.train.write_graph(
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
# also save the policy so we have optimized model files for each checkpoint
self.export(checkpoint_path, brain_name)
return checkpoint_path
def export(self, output_filepath: str, brain_name: str) -> None:
# save model if there is only one worker or
# only on worker-0 if there are multiple workers
if self.policy and self.policy.rank is not None and self.policy.rank != 0:
return
export_policy_model(
self.model_path, output_filepath, brain_name, self.graph, self.sess
)
def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
# If there is an initialize path, load from that. Else, load from the set model path.
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
# e.g., resume from an initialize path.
if policy is None:
policy = self.policy
policy = cast(TFPolicy, policy)
reset_steps = not self.load
if self.initialize_path is not None:
self._load_graph(
policy, self.initialize_path, reset_global_steps=reset_steps
)
elif self.load:
self._load_graph(policy, self.model_path, reset_global_steps=reset_steps)
else:
policy.initialize()
TFPolicy.broadcast_global_variables(0)
def _load_graph(
self, policy: TFPolicy, model_path: str, reset_global_steps: bool = False
) -> None:
with policy.graph.as_default():
logger.info(f"Loading model from {model_path}.")
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt is None:
raise UnityPolicyException(
"The model {} could not be loaded. Make "
"sure you specified the right "
"--run-id and that the previous run you are loading from had the same "
"behavior names.".format(model_path)
)
if self.tf_saver:
try:
self.tf_saver.restore(policy.sess, ckpt.model_checkpoint_path)
except tf.errors.NotFoundError:
raise UnityPolicyException(
"The model {} was found but could not be loaded. Make "
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
"and is using the same trainer configuration as the current run.".format(
model_path
)
)
self._check_model_version(__version__)
if reset_global_steps:
policy.set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path
)
)
else:
logger.info(f"Resuming training from step {policy.get_current_step()}.")
def _check_model_version(self, version: str) -> None:
"""
Checks whether the model being loaded was created with the same version of
ML-Agents, and throw a warning if not so.
"""
if self.policy is not None and self.policy.version_tensors is not None:
loaded_ver = tuple(
num.eval(session=self.sess) for num in self.policy.version_tensors
)
if loaded_ver != TFPolicy._convert_version_string(version):
logger.warning(
f"The model checkpoint you are loading from was saved with ML-Agents version "
f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents"
f"version is {version}. Model may not behave properly."
)
def copy_final_model(self, source_nn_path: str) -> None:
"""
Copy the .nn file at the given source to the destination.
Also copies the corresponding .onnx file if it exists.
"""
final_model_name = os.path.splitext(source_nn_path)[0]
if SerializationSettings.convert_to_barracuda:
source_path = f"{final_model_name}.nn"
destination_path = f"{self.model_path}.nn"
shutil.copyfile(source_path, destination_path)
logger.info(f"Copied {source_path} to {destination_path}.")
if SerializationSettings.convert_to_onnx:
try:
source_path = f"{final_model_name}.onnx"
destination_path = f"{self.model_path}.onnx"
shutil.copyfile(source_path, destination_path)
logger.info(f"Copied {source_path} to {destination_path}.")
except OSError:
pass

247
ml-agents/mlagents/model_serialization.py


from distutils.util import strtobool
import os
import shutil
from typing import Any, List, Set, NamedTuple
from distutils.version import LooseVersion
try:
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
from tf2onnx import optimizer
ONNX_EXPORT_ENABLED = True
except ImportError:
# Either onnx and tf2onnx not installed, or they're not compatible with the version of tensorflow
ONNX_EXPORT_ENABLED = False
pass
from mlagents.tf_utils import tf
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.tf import tensorflow_to_barracuda as tf2bc
if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):
# ONNX is only tested on 1.12.0 and later
ONNX_EXPORT_ENABLED = False
logger = get_logger(__name__)
POSSIBLE_INPUT_NODES = frozenset(
[
"action_masks",
"epsilon",
"prev_action",
"recurrent_in",
"sequence_length",
"vector_observation",
]
)
POSSIBLE_OUTPUT_NODES = frozenset(
["action", "action_probs", "recurrent_out", "value_estimate"]
)
MODEL_CONSTANTS = frozenset(
[
"action_output_shape",
"is_continuous_control",
"memory_size",
"version_number",
"trainer_major_version",
"trainer_minor_version",
"trainer_patch_version",
]
)
VISUAL_OBSERVATION_PREFIX = "visual_observation_"
class SerializationSettings(NamedTuple):
model_path: str
brain_name: str
convert_to_barracuda: bool = True
convert_to_onnx: bool = True
onnx_opset: int = 9
def export_policy_model(
output_filepath: str,
settings: SerializationSettings,
graph: tf.Graph,
sess: tf.Session,
) -> None:
"""
Exports a TF graph for a Policy to .nn and/or .onnx format for Unity embedding.
:param output_filepath: file path to output the model (without file suffix)
:param settings: SerializationSettings describing how to export the model
:param graph: Tensorflow Graph for the policy
:param sess: Tensorflow session for the policy
"""
frozen_graph_def = _make_frozen_graph(settings, graph, sess)
if not os.path.exists(settings.model_path):
os.makedirs(settings.model_path)
# Save frozen graph
frozen_graph_def_path = settings.model_path + "/frozen_graph_def.pb"
with gfile.GFile(frozen_graph_def_path, "wb") as f:
f.write(frozen_graph_def.SerializeToString())
# Convert to barracuda
if settings.convert_to_barracuda:
tf2bc.convert(frozen_graph_def_path, f"{output_filepath}.nn")
logger.info(f"Exported {output_filepath}.nn")
# Save to onnx too (if we were able to import it)
if ONNX_EXPORT_ENABLED:
if settings.convert_to_onnx:
try:
onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
onnx_output_path = f"{output_filepath}.onnx"
with open(onnx_output_path, "wb") as f:
f.write(onnx_graph.SerializeToString())
logger.info(f"Converting to {onnx_output_path}")
except Exception:
# Make conversion errors fatal depending on environment variables (only done during CI)
if _enforce_onnx_conversion():
raise
logger.exception(
"Exception trying to save ONNX graph. Please report this error on "
"https://github.com/Unity-Technologies/ml-agents/issues and "
"attach a copy of frozen_graph_def.pb"
)
else:
if _enforce_onnx_conversion():
raise RuntimeError(
"ONNX conversion enforced, but couldn't import dependencies."
)
def _make_frozen_graph(
settings: SerializationSettings, graph: tf.Graph, sess: tf.Session
) -> tf.GraphDef:
with graph.as_default():
target_nodes = ",".join(_process_graph(settings, graph))
graph_def = graph.as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess, graph_def, target_nodes.replace(" ", "").split(",")
)
return output_graph_def
def convert_frozen_to_onnx(
settings: SerializationSettings, frozen_graph_def: tf.GraphDef
) -> Any:
# This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
inputs = _get_input_node_names(frozen_graph_def)
outputs = _get_output_node_names(frozen_graph_def)
logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")
frozen_graph_def = tf_optimize(
inputs, outputs, frozen_graph_def, fold_constant=True
)
with tf.Graph().as_default() as tf_graph:
tf.import_graph_def(frozen_graph_def, name="")
with tf.Session(graph=tf_graph):
g = process_tf_graph(
tf_graph,
input_names=inputs,
output_names=outputs,
opset=settings.onnx_opset,
)
onnx_graph = optimizer.optimize_graph(g)
model_proto = onnx_graph.make_model(settings.brain_name)
return model_proto
def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
"""
Get the list of input node names from the graph.
Names are suffixed with ":0"
"""
node_names = _get_frozen_graph_node_names(frozen_graph_def)
input_names = node_names & POSSIBLE_INPUT_NODES
# Check visual inputs sequentially, and exit as soon as we don't find one
vis_index = 0
while True:
vis_node_name = f"{VISUAL_OBSERVATION_PREFIX}{vis_index}"
if vis_node_name in node_names:
input_names.add(vis_node_name)
else:
break
vis_index += 1
# Append the port
return [f"{n}:0" for n in input_names]
def _get_output_node_names(frozen_graph_def: Any) -> List[str]:
"""
Get the list of output node names from the graph.
Also include constants, so that they will be readable by the
onnx importer.
Names are suffixed with ":0"
"""
node_names = _get_frozen_graph_node_names(frozen_graph_def)
output_names = node_names & (POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS)
# Append the port
return [f"{n}:0" for n in output_names]
def _get_frozen_graph_node_names(frozen_graph_def: Any) -> Set[str]:
"""
Get all the node names from the graph.
"""
names = set()
for node in frozen_graph_def.node:
names.add(node.name)
return names
def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str]:
"""
Gets the list of the output nodes present in the graph for inference
:return: list of node names
"""
all_nodes = [x.name for x in graph.as_graph_def().node]
nodes = [x for x in all_nodes if x in POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS]
logger.info("List of nodes to export for brain :" + settings.brain_name)
for n in nodes:
logger.info("\t" + n)
return nodes
def _enforce_onnx_conversion() -> bool:
env_var_name = "TEST_ENFORCE_ONNX_CONVERSION"
if env_var_name not in os.environ:
return False
val = os.environ[env_var_name]
try:
# This handles e.g. "false" converting reasonably to False
return strtobool(val)
except Exception:
return False
def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None:
"""
Copy the .nn file at the given source to the destination.
Also copies the corresponding .onnx file if it exists.
"""
shutil.copyfile(source_nn_path, destination_nn_path)
logger.info(f"Copied {source_nn_path} to {destination_nn_path}.")
# Copy the onnx file if it exists
source_onnx_path = os.path.splitext(source_nn_path)[0] + ".onnx"
destination_onnx_path = os.path.splitext(destination_nn_path)[0] + ".onnx"
try:
shutil.copyfile(source_onnx_path, destination_onnx_path)
logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.")
except OSError:
pass
正在加载...
取消
保存