|
|
|
|
|
|
class SerializationSettings(NamedTuple): |
|
|
|
model_path: str |
|
|
|
brain_name: str |
|
|
|
checkpoint_path: str = "" |
|
|
|
convert_to_barracuda: bool = True |
|
|
|
convert_to_onnx: bool = True |
|
|
|
onnx_opset: int = 9 |
|
|
|
|
|
|
settings: SerializationSettings, graph: tf.Graph, sess: tf.Session |
|
|
|
settings: SerializationSettings, graph: tf.Graph, sess: tf.Session, is_checkpoint: bool = False |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Exports latest saved model to .nn format for Unity embedding. |
|
|
|
|
|
|
|
|
|
|
# Convert to barracuda |
|
|
|
if settings.convert_to_barracuda: |
|
|
|
tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn") |
|
|
|
logger.info(f"Exported {settings.model_path}.nn file") |
|
|
|
if is_checkpoint : |
|
|
|
tf2bc.convert(frozen_graph_def_path, os.path.join(settings.model_path, f"{settings.checkpoint_path}.nn")) |
|
|
|
logger.info(f"Exported {settings.checkpoint_path}.nn file") |
|
|
|
else: |
|
|
|
tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn") |
|
|
|
logger.info(f"Exported {settings.model_path}.nn file") |
|
|
|
|
|
|
|
# Save to onnx too (if we were able to import it) |
|
|
|
if ONNX_EXPORT_ENABLED: |
|
|
|