|
|
|
|
|
|
def export_policy_model(
    output_filepath: str,
    behavior_name: str,
    graph: tf.Graph,
    sess: tf.Session,
) -> None:
    """
    Freeze the policy graph and export it for embedding, optionally also as ONNX.

    :param output_filepath: file path to output the model (without file suffix)
    :param behavior_name: behavior name of the trained model
    :param graph: the TF graph containing the trained policy
    :param sess: the TF session holding the trained variable values
    """
    frozen_graph_def = _make_frozen_graph(behavior_name, graph, sess)
    if not os.path.exists(output_filepath):
        os.makedirs(output_filepath)

    # Save frozen graph
    # NOTE(review): the original save statements were lost in a bad merge of two
    # file versions; this is the minimal reconstruction — confirm the expected
    # frozen-graph filename against the consumer of this artifact.
    frozen_graph_def_path = f"{output_filepath}/frozen_graph_def.pb"
    with open(frozen_graph_def_path, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())

    # Convert to ONNX only when the tf2onnx dependency was importable at module
    # load time AND the user asked for ONNX conversion.
    if ONNX_EXPORT_ENABLED:
        if SerializationSettings.convert_to_onnx:
            try:
                onnx_graph = convert_frozen_to_onnx(behavior_name, frozen_graph_def)
                onnx_output_path = f"{output_filepath}.onnx"
                with open(onnx_output_path, "wb") as f:
                    f.write(onnx_graph.SerializeToString())
            except Exception:
                # Best-effort: ONNX export failure must not lose the frozen graph
                # that was already written above.
                logger.exception(
                    "Exception while converting to ONNX. The frozen graph was still saved."
                )
|
|
def _make_frozen_graph(
    behavior_name: str, graph: tf.Graph, sess: tf.Session
) -> tf.GraphDef:
    """
    Convert the session's variables to constants, keeping only the nodes
    needed for inference (as selected by _process_graph).

    :param behavior_name: behavior name of the trained model (used for node selection/logging)
    :param graph: the TF graph to freeze
    :param sess: the session holding the variable values to bake in
    :return: a frozen GraphDef with variables replaced by constants
    """
    with graph.as_default():
        target_nodes = ",".join(_process_graph(behavior_name, graph))
        graph_def = graph.as_graph_def()
        # convert_variables_to_constants wants a list of node names; strip any
        # stray spaces before splitting the comma-joined string back apart.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, graph_def, target_nodes.replace(" ", "").split(",")
        )
    return output_graph_def
|
|
def convert_frozen_to_onnx(behavior_name: str, frozen_graph_def: tf.GraphDef) -> Any:
    """
    Convert a frozen TF GraphDef to an ONNX ModelProto via tf2onnx.

    :param behavior_name: behavior name, used as the ONNX model name
    :param frozen_graph_def: frozen graph (variables already converted to constants)
    :return: the ONNX ModelProto
    """
    # This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
    inputs = _get_input_node_names(frozen_graph_def)
    outputs = _get_output_node_names(frozen_graph_def)

    # NOTE(review): the middle of this function was lost in a bad merge of two
    # file versions; reconstructed from the tf2onnx conversion flow referenced
    # above — confirm against the original file.
    frozen_graph_def = tf_optimize(inputs, outputs, frozen_graph_def, fold_constant=True)

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(frozen_graph_def, name="")
    with tf.Session(graph=tf_graph):
        g = process_tf_graph(
            tf_graph,
            input_names=inputs,
            output_names=outputs,
            opset=SerializationSettings.onnx_opset,
        )

    onnx_graph = optimizer.optimize_graph(g)
    model_proto = onnx_graph.make_model(behavior_name)

    return model_proto
|
|
|
|
|
|
|
|
|
|
return names |
|
|
|
|
|
|
|
|
|
|
|
def _process_graph(behavior_name: str, graph: tf.Graph) -> List[str]:
    """
    Gets the list of the output nodes present in the graph for inference
    :param behavior_name: behavior name of the trained model (logging only)
    :param graph: the TF graph to scan for exportable nodes
    :return: list of node names
    """
    # NOTE(review): the node-selection lines were lost in a bad merge of two
    # file versions; reconstructed as "graph nodes whose name is a known
    # output/constant node" — confirm POSSIBLE_OUTPUT_NODES / MODEL_CONSTANTS
    # against the module-level constants in the original file.
    all_nodes = [x.name for x in graph.as_graph_def().node]
    nodes = [x for x in all_nodes if x in POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS]
    logger.info("List of nodes to export for behavior :" + behavior_name)
    for n in nodes:
        logger.info("\t" + n)
    return nodes
|
|
|