4 years ago
Current commit
25dc8c3d
27 files changed, with 650 insertions and 525 deletions
8    ml-agents-envs/mlagents_envs/exception.py
2    ml-agents/mlagents/trainers/ghost/trainer.py
15   ml-agents/mlagents/trainers/policy/policy.py
112  ml-agents/mlagents/trainers/policy/tf_policy.py
2    ml-agents/mlagents/trainers/ppo/optimizer.py
10   ml-agents/mlagents/trainers/ppo/trainer.py
2    ml-agents/mlagents/trainers/sac/optimizer.py
10   ml-agents/mlagents/trainers/sac/trainer.py
6    ml-agents/mlagents/trainers/settings.py
27   ml-agents/mlagents/trainers/tests/test_barracuda_converter.py
10   ml-agents/mlagents/trainers/tests/test_bcmodule.py
62   ml-agents/mlagents/trainers/tests/test_nn_policy.py
8    ml-agents/mlagents/trainers/tests/test_ppo.py
1    ml-agents/mlagents/trainers/tests/test_reward_signals.py
21   ml-agents/mlagents/trainers/tests/test_rl_trainer.py
6    ml-agents/mlagents/trainers/tests/test_sac.py
8    ml-agents/mlagents/trainers/tests/test_simple_rl.py
20   ml-agents/mlagents/trainers/tests/test_tf_policy.py
26   ml-agents/mlagents/trainers/trainer/rl_trainer.py
2    ml-agents/mlagents/trainers/trainer/trainer.py
113  ml-agents/mlagents/trainers/tests/test_saver.py
221  ml-agents/mlagents/trainers/tf/model_serialization.py
0    ml-agents/mlagents/trainers/saver/__init__.py
66   ml-agents/mlagents/trainers/saver/saver.py
170  ml-agents/mlagents/trainers/saver/tf_saver.py
247  ml-agents/mlagents/model_serialization.py

ml-agents/mlagents/trainers/tests/test_saver.py

import pytest
from unittest import mock
import os
import unittest
import tempfile

import numpy as np
from mlagents.tf_utils import tf
from mlagents.trainers.saver.tf_saver import TFSaver
from mlagents.trainers import __version__
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.test_nn_policy import create_policy_mock
from mlagents.trainers.ppo.optimizer import PPOOptimizer


def test_register(tmp_path):
    trainer_params = TrainerSettings()
    saver = TFSaver(trainer_params, tmp_path)

    opt = mock.Mock(spec=PPOOptimizer)
    saver.register(opt)
    assert saver.policy is None

    trainer_params = TrainerSettings()
    policy = create_policy_mock(trainer_params)
    saver.register(policy)
    assert saver.policy is not None


class ModelVersionTest(unittest.TestCase):
    def test_version_compare(self):
        # Test that a version mismatch produces exactly one warning
        with self.assertLogs("mlagents.trainers", level="WARNING") as cm:
            trainer_params = TrainerSettings()
            mock_path = tempfile.mkdtemp()
            policy = create_policy_mock(trainer_params)
            saver = TFSaver(trainer_params, mock_path)
            saver.register(policy)

            saver._check_model_version(
                "0.0.0"
            )  # This is not the right version for sure
            # Assert that 1 warning has been thrown with incorrect version
            assert len(cm.output) == 1
            saver._check_model_version(__version__)  # This should be the right version
            # Assert that no additional warnings have been thrown with correct version
            assert len(cm.output) == 1


def test_load_save(tmp_path):
    path1 = os.path.join(tmp_path, "runid1")
    path2 = os.path.join(tmp_path, "runid2")
    trainer_params = TrainerSettings()
    policy = create_policy_mock(trainer_params)
    saver = TFSaver(trainer_params, path1)
    saver.register(policy)
    saver.initialize_or_load(policy)
    policy.set_step(2000)

    mock_brain_name = "MockBrain"
    saver.save_checkpoint(mock_brain_name, 2000)
    assert len(os.listdir(tmp_path)) > 0

    # Try load from this path
    saver = TFSaver(trainer_params, path1, load=True)
    policy2 = create_policy_mock(trainer_params)
    saver.register(policy2)
    saver.initialize_or_load(policy2)
    _compare_two_policies(policy, policy2)
    assert policy2.get_current_step() == 2000

    # Try initialize from path 1
    trainer_params.init_path = path1
    saver = TFSaver(trainer_params, path2)
    policy3 = create_policy_mock(trainer_params)
    saver.register(policy3)
    saver.initialize_or_load(policy3)

    _compare_two_policies(policy2, policy3)
    # Assert that the steps are 0.
    assert policy3.get_current_step() == 0


def _compare_two_policies(policy1: TFPolicy, policy2: TFPolicy) -> None:
    """
    Make sure two policies have the same output for the same input.
    """
    decision_step, _ = mb.create_steps_from_behavior_spec(
        policy1.behavior_spec, num_agents=1
    )
    run_out1 = policy1.evaluate(decision_step, list(decision_step.agent_id))
    run_out2 = policy2.evaluate(decision_step, list(decision_step.agent_id))

    np.testing.assert_array_equal(run_out2["log_probs"], run_out1["log_probs"])


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_checkpoint_conversion(tmpdir, rnn, visual, discrete):
    tf.reset_default_graph()
    dummy_config = TrainerSettings()
    model_path = os.path.join(tmpdir, "Mock_Brain")
    policy = create_policy_mock(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    trainer_params = TrainerSettings()
    saver = TFSaver(trainer_params, model_path)
    saver.register(policy)
    saver.save_checkpoint("Mock_Brain", 100)
    assert os.path.isfile(model_path + "/Mock_Brain-100.nn")

ml-agents/mlagents/trainers/tf/model_serialization.py

from distutils.util import strtobool
import os
from typing import Any, List, Set
from distutils.version import LooseVersion

try:
    from tf2onnx.tfonnx import process_tf_graph, tf_optimize
    from tf2onnx import optimizer

    ONNX_EXPORT_ENABLED = True
except ImportError:
    # Either onnx and tf2onnx are not installed, or they're not compatible
    # with the installed version of tensorflow
    ONNX_EXPORT_ENABLED = False
    pass

from mlagents.tf_utils import tf

from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.settings import SerializationSettings
from mlagents.trainers.tf import tensorflow_to_barracuda as tf2bc

if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):
    # ONNX is only tested on 1.12.0 and later
    ONNX_EXPORT_ENABLED = False

logger = get_logger(__name__)


POSSIBLE_INPUT_NODES = frozenset(
    [
        "action_masks",
        "epsilon",
        "prev_action",
        "recurrent_in",
        "sequence_length",
        "vector_observation",
    ]
)

POSSIBLE_OUTPUT_NODES = frozenset(
    ["action", "action_probs", "recurrent_out", "value_estimate"]
)

MODEL_CONSTANTS = frozenset(
    [
        "action_output_shape",
        "is_continuous_control",
        "memory_size",
        "version_number",
        "trainer_major_version",
        "trainer_minor_version",
        "trainer_patch_version",
    ]
)
VISUAL_OBSERVATION_PREFIX = "visual_observation_"


def export_policy_model(
    model_path: str,
    output_filepath: str,
    brain_name: str,
    graph: tf.Graph,
    sess: tf.Session,
) -> None:
    """
    Exports a TF graph for a Policy to .nn and/or .onnx format for Unity embedding.

    :param model_path: directory in which the frozen graph definition is saved
    :param output_filepath: file path to output the model (without file suffix)
    :param brain_name: brain name of the trained model
    :param graph: Tensorflow Graph for the policy
    :param sess: Tensorflow session for the policy
    """
    frozen_graph_def = _make_frozen_graph(brain_name, graph, sess)
    if not os.path.exists(output_filepath):
        os.makedirs(output_filepath)
    # Save frozen graph
    frozen_graph_def_path = model_path + "/frozen_graph_def.pb"
    with gfile.GFile(frozen_graph_def_path, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())

    # Convert to barracuda
    if SerializationSettings.convert_to_barracuda:
        tf2bc.convert(frozen_graph_def_path, f"{output_filepath}.nn")
        logger.info(f"Exported {output_filepath}.nn")

    # Save to onnx too (if we were able to import it)
    if ONNX_EXPORT_ENABLED:
        if SerializationSettings.convert_to_onnx:
            try:
                onnx_graph = convert_frozen_to_onnx(brain_name, frozen_graph_def)
                onnx_output_path = f"{output_filepath}.onnx"
                with open(onnx_output_path, "wb") as f:
                    f.write(onnx_graph.SerializeToString())
                logger.info(f"Converting to {onnx_output_path}")
            except Exception:
                # Make conversion errors fatal depending on environment variables (only done during CI)
                if _enforce_onnx_conversion():
                    raise
                logger.exception(
                    "Exception trying to save ONNX graph. Please report this error on "
                    "https://github.com/Unity-Technologies/ml-agents/issues and "
                    "attach a copy of frozen_graph_def.pb"
                )
    else:
        if _enforce_onnx_conversion():
            raise RuntimeError(
                "ONNX conversion enforced, but couldn't import dependencies."
            )


def _make_frozen_graph(
    brain_name: str, graph: tf.Graph, sess: tf.Session
) -> tf.GraphDef:
    with graph.as_default():
        target_nodes = ",".join(_process_graph(brain_name, graph))
        graph_def = graph.as_graph_def()
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, graph_def, target_nodes.replace(" ", "").split(",")
        )
    return output_graph_def


def convert_frozen_to_onnx(brain_name: str, frozen_graph_def: tf.GraphDef) -> Any:
    # This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py

    inputs = _get_input_node_names(frozen_graph_def)
    outputs = _get_output_node_names(frozen_graph_def)
    logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")

    frozen_graph_def = tf_optimize(
        inputs, outputs, frozen_graph_def, fold_constant=True
    )

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(frozen_graph_def, name="")
        with tf.Session(graph=tf_graph):
            g = process_tf_graph(
                tf_graph,
                input_names=inputs,
                output_names=outputs,
                opset=SerializationSettings.onnx_opset,
            )

    onnx_graph = optimizer.optimize_graph(g)
    model_proto = onnx_graph.make_model(brain_name)

    return model_proto


def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
    """
    Get the list of input node names from the graph.
    Names are suffixed with ":0"
    """
    node_names = _get_frozen_graph_node_names(frozen_graph_def)
    input_names = node_names & POSSIBLE_INPUT_NODES

    # Check visual inputs sequentially, and exit as soon as we don't find one
    vis_index = 0
    while True:
        vis_node_name = f"{VISUAL_OBSERVATION_PREFIX}{vis_index}"
        if vis_node_name in node_names:
            input_names.add(vis_node_name)
        else:
            break
        vis_index += 1
    # Append the port
    return [f"{n}:0" for n in input_names]


def _get_output_node_names(frozen_graph_def: Any) -> List[str]:
    """
    Get the list of output node names from the graph.
    Also include constants, so that they will be readable by the
    onnx importer.
    Names are suffixed with ":0"
    """
    node_names = _get_frozen_graph_node_names(frozen_graph_def)
    output_names = node_names & (POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS)
    # Append the port
    return [f"{n}:0" for n in output_names]


def _get_frozen_graph_node_names(frozen_graph_def: Any) -> Set[str]:
    """
    Get all the node names from the graph.
    """
    names = set()
    for node in frozen_graph_def.node:
        names.add(node.name)
    return names


def _process_graph(brain_name: str, graph: tf.Graph) -> List[str]:
    """
    Gets the list of the output nodes present in the graph for inference
    :return: list of node names
    """
    all_nodes = [x.name for x in graph.as_graph_def().node]
    nodes = [x for x in all_nodes if x in POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS]
    logger.info("List of nodes to export for brain :" + brain_name)
    for n in nodes:
        logger.info("\t" + n)
    return nodes


def _enforce_onnx_conversion() -> bool:
    env_var_name = "TEST_ENFORCE_ONNX_CONVERSION"
    if env_var_name not in os.environ:
        return False

    val = os.environ[env_var_name]
    try:
        # This handles e.g. "false" converting reasonably to False
        return strtobool(val)
    except Exception:
        return False
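
For context, here is a minimal sketch (not part of this commit) of driving export_policy_model directly, reusing the create_policy_mock helper from the tests above; the run directory "results/run1" and behavior name "MyBehavior" are hypothetical placeholders, and it assumes the mock policy's variables are initialized before freezing.

import os
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.tests.test_nn_policy import create_policy_mock
from mlagents.trainers.tf.model_serialization import export_policy_model

policy = create_policy_mock(TrainerSettings())
policy.initialize()  # variables must exist before convert_variables_to_constants
model_path = "results/run1"
os.makedirs(model_path, exist_ok=True)
export_policy_model(
    model_path,                                # where frozen_graph_def.pb lands
    os.path.join(model_path, "MyBehavior-0"),  # prefix; .nn/.onnx suffixes are appended
    "MyBehavior",
    policy.graph,
    policy.sess,
)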

ml-agents/mlagents/trainers/saver/saver.py

# # Unity ML-Agents Toolkit
import abc
from typing import Any


class BaseSaver(abc.ABC):
    """This class is the base class for the Saver"""

    def __init__(self):
        pass

    @abc.abstractmethod
    def register(self, module: Any) -> None:
        """
        Register the modules to the Saver.
        The Saver will store the module and include it in the saved files
        when saving checkpoint/exporting graph.
        :param module: the module to be registered
        """
        pass

    def _register_policy(self, policy):
        """
        Helper function for registering policy to the Saver.
        :param policy: the policy to be registered
        """
        pass

    def _register_optimizer(self, optimizer):
        """
        Helper function for registering optimizer to the Saver.
        :param optimizer: the optimizer to be registered
        """
        pass

    @abc.abstractmethod
    def save_checkpoint(self, brain_name: str, step: int) -> str:
        """
        Checkpoints the policy on disk.
        :param brain_name: Brain name of the brain to be trained
        :param step: the current training step
        """
        pass

    @abc.abstractmethod
    def export(self, output_filepath: str, brain_name: str) -> None:
        """
        Saves the serialized model, given a path and brain name.
        This method will save the policy graph to the given filepath. The path
        should be provided without an extension as multiple serialized model formats
        may be generated as a result.
        :param output_filepath: path (without suffix) for the model file(s)
        :param brain_name: Brain name of the brain to be trained.
        """
        pass

    @abc.abstractmethod
    def initialize_or_load(self, policy):
        """
        Initialize/Load registered modules by default.
        If a policy is given, perform the operation on that policy instead.
        This argument is mainly for the initialization of the ghost trainer's fixed policy.
        :param policy: if given, perform the initializing/loading on this policy;
        otherwise, use the registered policy.
        """
        pass
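
To make the contract concrete, here is a minimal sketch (not part of this commit) of a concrete subclass satisfying the abstract interface; NoOpSaver and its behavior are hypothetical, modeled loosely on how TFSaver dispatches registrations.

from typing import Any
from mlagents.trainers.saver.saver import BaseSaver


class NoOpSaver(BaseSaver):
    """Hypothetical saver that records a policy but writes nothing to disk."""

    def __init__(self):
        super().__init__()
        self.policy = None

    def register(self, module: Any) -> None:
        # Keep the first module we see, mirroring TFSaver's single-policy rule.
        if self.policy is None:
            self.policy = module

    def save_checkpoint(self, brain_name: str, step: int) -> str:
        # Return the checkpoint prefix; a real saver would write files here.
        return f"{brain_name}-{step}"

    def export(self, output_filepath: str, brain_name: str) -> None:
        pass  # nothing serialized in this sketch

    def initialize_or_load(self, policy=None) -> None:
        pass  # a real saver would initialize fresh weights or restore here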

ml-agents/mlagents/trainers/saver/tf_saver.py

import os
import shutil
from typing import Optional, Union, cast
from mlagents_envs.exception import UnityPolicyException
from mlagents_envs.logging_util import get_logger
from mlagents.tf_utils import tf
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.tf.model_serialization import export_policy_model
from mlagents.trainers.settings import TrainerSettings, SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers import __version__


logger = get_logger(__name__)


class TFSaver(BaseSaver):
    """
    Saver class for TensorFlow
    """

    def __init__(
        self, trainer_settings: TrainerSettings, model_path: str, load: bool = False
    ):
        super().__init__()
        self.model_path = model_path
        self.initialize_path = trainer_settings.init_path
        self._keep_checkpoints = trainer_settings.keep_checkpoints
        self.load = load

        # Currently only support saving one policy. This is the one to be saved.
        self.policy: Optional[TFPolicy] = None
        self.graph = None
        self.sess = None
        self.tf_saver = None

    def register(self, module: Union[TFPolicy, TFOptimizer]) -> None:
        if isinstance(module, TFPolicy):
            self._register_policy(module)
        elif isinstance(module, TFOptimizer):
            self._register_optimizer(module)
        else:
            raise UnityPolicyException(
                "Registering Object of unsupported type {} to Saver ".format(
                    type(module)
                )
            )

    def _register_policy(self, policy: TFPolicy) -> None:
        if self.policy is None:
            self.policy = policy
            self.graph = self.policy.graph
            self.sess = self.policy.sess
            with self.policy.graph.as_default():
                self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)

    def save_checkpoint(self, brain_name: str, step: int) -> str:
        checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
        # Save the TF checkpoint and graph definition
        if self.graph:
            with self.graph.as_default():
                if self.tf_saver:
                    self.tf_saver.save(self.sess, f"{checkpoint_path}.ckpt")
                tf.train.write_graph(
                    self.graph, self.model_path, "raw_graph_def.pb", as_text=False
                )
        # also save the policy so we have optimized model files for each checkpoint
        self.export(checkpoint_path, brain_name)
        return checkpoint_path

    def export(self, output_filepath: str, brain_name: str) -> None:
        # save model if there is only one worker or
        # only on worker-0 if there are multiple workers
        if self.policy and self.policy.rank is not None and self.policy.rank != 0:
            return
        export_policy_model(
            self.model_path, output_filepath, brain_name, self.graph, self.sess
        )

    def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
        # If there is an initialize path, load from that. Else, load from the set model path.
        # If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
        # e.g., resume from an initialize path.
        if policy is None:
            policy = self.policy
        policy = cast(TFPolicy, policy)
        reset_steps = not self.load
        if self.initialize_path is not None:
            self._load_graph(
                policy, self.initialize_path, reset_global_steps=reset_steps
            )
        elif self.load:
            self._load_graph(policy, self.model_path, reset_global_steps=reset_steps)
        else:
            policy.initialize()
        TFPolicy.broadcast_global_variables(0)

    def _load_graph(
        self, policy: TFPolicy, model_path: str, reset_global_steps: bool = False
    ) -> None:
        with policy.graph.as_default():
            logger.info(f"Loading model from {model_path}.")
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt is None:
                raise UnityPolicyException(
                    "The model {} could not be loaded. Make "
                    "sure you specified the right "
                    "--run-id and that the previous run you are loading from had the same "
                    "behavior names.".format(model_path)
                )
            if self.tf_saver:
                try:
                    self.tf_saver.restore(policy.sess, ckpt.model_checkpoint_path)
                except tf.errors.NotFoundError:
                    raise UnityPolicyException(
                        "The model {} was found but could not be loaded. Make "
                        "sure the model is from the same version of ML-Agents, has the same behavior parameters, "
                        "and is using the same trainer configuration as the current run.".format(
                            model_path
                        )
                    )
            self._check_model_version(__version__)
            if reset_global_steps:
                policy.set_step(0)
                logger.info(
                    "Starting training from step 0 and saving to {}.".format(
                        self.model_path
                    )
                )
            else:
                logger.info(f"Resuming training from step {policy.get_current_step()}.")

    def _check_model_version(self, version: str) -> None:
        """
        Checks whether the model being loaded was created with the same version of
        ML-Agents, and throws a warning if not.
        """
        if self.policy is not None and self.policy.version_tensors is not None:
            loaded_ver = tuple(
                num.eval(session=self.sess) for num in self.policy.version_tensors
            )
            if loaded_ver != TFPolicy._convert_version_string(version):
                logger.warning(
                    f"The model checkpoint you are loading from was saved with ML-Agents version "
                    f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents "
                    f"version is {version}. Model may not behave properly."
                )

    def copy_final_model(self, source_nn_path: str) -> None:
        """
        Copy the .nn file at the given source to the destination.
        Also copies the corresponding .onnx file if it exists.
        """
        final_model_name = os.path.splitext(source_nn_path)[0]

        if SerializationSettings.convert_to_barracuda:
            source_path = f"{final_model_name}.nn"
            destination_path = f"{self.model_path}.nn"
            shutil.copyfile(source_path, destination_path)
            logger.info(f"Copied {source_path} to {destination_path}.")

        if SerializationSettings.convert_to_onnx:
            try:
                source_path = f"{final_model_name}.onnx"
                destination_path = f"{self.model_path}.onnx"
                shutil.copyfile(source_path, destination_path)
                logger.info(f"Copied {source_path} to {destination_path}.")
            except OSError:
                pass
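
As a usage note, here is a short sketch of the saver lifecycle this class implements, mirroring test_load_save in test_saver.py above; it reuses the create_policy_mock test helper, and the run directory "results/run1" and brain name "MyBrain" are placeholders.

from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.saver.tf_saver import TFSaver
from mlagents.trainers.tests.test_nn_policy import create_policy_mock

settings = TrainerSettings()
policy = create_policy_mock(settings)

saver = TFSaver(settings, "results/run1")  # load=True would resume instead
saver.register(policy)                     # wires up graph/sess and tf.train.Saver
saver.initialize_or_load()                 # fresh initialization since load=False
checkpoint = saver.save_checkpoint("MyBrain", step=1000)
# Writes results/run1/MyBrain-1000.ckpt and exports MyBrain-1000.nn (and .onnx
# if tf2onnx is available), per save_checkpoint/export above.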

ml-agents/mlagents/model_serialization.py

from distutils.util import strtobool
import os
import shutil
from typing import Any, List, Set, NamedTuple
from distutils.version import LooseVersion

try:
    from tf2onnx.tfonnx import process_tf_graph, tf_optimize
    from tf2onnx import optimizer

    ONNX_EXPORT_ENABLED = True
except ImportError:
    # Either onnx and tf2onnx are not installed, or they're not compatible
    # with the installed version of tensorflow
    ONNX_EXPORT_ENABLED = False
    pass

from mlagents.tf_utils import tf

from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util

from mlagents_envs.logging_util import get_logger