|
|
|
|
|
|
from distutils.version import LooseVersion |
|
|
|
|
|
|
|
try: |
|
|
|
import onnx |
|
|
|
from tf2onnx.tfonnx import process_tf_graph, tf_optimize |
|
|
|
from tf2onnx import optimizer |
|
|
|
|
|
|
|
|
|
|
) -> Any: |
|
|
|
# This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py |
|
|
|
|
|
|
|
# Some constants in the graph need to be read by the inference system. |
|
|
|
# These aren't used by the model anywhere, so trying to make sure they propagate |
|
|
|
# through conversion and import is a losing battle. Instead, save them now, |
|
|
|
# so that we can add them back later. |
|
|
|
constant_values = {} |
|
|
|
for n in frozen_graph_def.node: |
|
|
|
if n.name in MODEL_CONSTANTS: |
|
|
|
val = n.attr["value"].tensor.int_val[0] |
|
|
|
constant_values[n.name] = val |
|
|
|
|
|
|
|
inputs = _get_input_node_names(frozen_graph_def) |
|
|
|
outputs = _get_output_node_names(frozen_graph_def) |
|
|
|
logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}") |
|
|
|
|
|
|
onnx_graph = optimizer.optimize_graph(g) |
|
|
|
model_proto = onnx_graph.make_model(settings.brain_name) |
|
|
|
|
|
|
|
# Save the constant values back the graph initializer. |
|
|
|
# This will ensure the importer gets them as global constants. |
|
|
|
constant_nodes = [] |
|
|
|
for k, v in constant_values.items(): |
|
|
|
constant_node = _make_onnx_node_for_constant(k, v) |
|
|
|
constant_nodes.append(constant_node) |
|
|
|
model_proto.graph.initializer.extend(constant_nodes) |
|
|
|
|
|
|
|
|
|
|
|
def _make_onnx_node_for_constant(name: str, value: int) -> Any: |
|
|
|
tensor_value = onnx.TensorProto( |
|
|
|
data_type=onnx.TensorProto.INT32, |
|
|
|
name=name, |
|
|
|
int32_data=[value], |
|
|
|
dims=[1, 1, 1, 1], |
|
|
|
) |
|
|
|
return tensor_value |
|
|
|
|
|
|
|
|
|
|
|
def _get_input_node_names(frozen_graph_def: Any) -> List[str]: |
|
|
|
|
|
|
def _get_output_node_names(frozen_graph_def: Any) -> List[str]: |
|
|
|
""" |
|
|
|
Get the list of output node names from the graph. |
|
|
|
Also include constants, so that they will be readable by the |
|
|
|
onnx importer. |
|
|
|
output_names = node_names & POSSIBLE_OUTPUT_NODES |
|
|
|
output_names = node_names & (POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS) |
|
|
|
# Append the port |
|
|
|
return [f"{n}:0" for n in output_names] |
|
|
|
|
|
|
|