浏览代码

enforce onnx conversion (expect tf2 CI to fail) (#3600)

/release-0.15.0
Chris Elion 5 年前
当前提交
9c5fc33a
共有 6 个文件被更改,包括 54 次插入19 次删除
  1. 7
      .circleci/config.yml
  2. 2
      docs/Unity-Inference-Engine.md
  3. 55
      ml-agents/mlagents/model_serialization.py
  4. 3
      test_constraints_max_tf1_version.txt
  5. 2
      test_constraints_min_version.txt
  6. 4
      test_requirements.txt

7
.circleci/config.yml


pip_constraints:
type: string
description: Constraints file that is passed to "pip install". We constraint older versions of libraries for older python runtime, in order to help ensure compatibility.
enforce_onnx_conversion:
type: integer
default: 0
description: Whether to raise an exception if ONNX models couldn't be saved.
executor: << parameters.executor >>
working_directory: ~/repo

TEST_ENFORCE_ONNX_CONVERSION: << parameters.enforce_onnx_conversion >>
steps:
- checkout

pyversion: 3.7.3
# Test python 3.7 with the newest supported versions
pip_constraints: test_constraints_max_tf1_version.txt
# Make sure ONNX conversion passes here (recent version of tensorflow 1.x)
enforce_onnx_conversion: 1
- build_python:
name: python_3.7.3+tf2
executor: python373

2
docs/Unity-Inference-Engine.md


* ONNX (`.onnx`) files use an [industry-standard open format](https://onnx.ai/about.html) produced by the [tf2onnx package](https://github.com/onnx/tensorflow-onnx).
Export to ONNX is currently considered beta. To enable it, make sure `tf2onnx>=1.5.5` is installed in pip.
tf2onnx does not currently support tensorflow 2.0.0 or later.
tf2onnx does not currently support tensorflow 2.0.0 or later, or earlier than 1.12.0.
## Using the Unity Inference Engine

55
ml-agents/mlagents/model_serialization.py


from distutils.util import strtobool
import os
from distutils.version import LooseVersion
try:
import onnx

from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
from mlagents.trainers import tensorflow_to_barracuda as tf2bc
if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):
# ONNX is only tested on 1.12.0 and later
ONNX_EXPORT_ENABLED = False
logger = logging.getLogger("mlagents.trainers")

logger.info(f"Exported {settings.model_path}.nn file")
# Save to onnx too (if we were able to import it)
if ONNX_EXPORT_ENABLED and settings.convert_to_onnx:
try:
onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
onnx_output_path = settings.model_path + ".onnx"
with open(onnx_output_path, "wb") as f:
f.write(onnx_graph.SerializeToString())
logger.info(f"Converting to {onnx_output_path}")
except Exception:
logger.exception(
"Exception trying to save ONNX graph. Please report this error on "
"https://github.com/Unity-Technologies/ml-agents/issues and "
"attach a copy of frozen_graph_def.pb"
if ONNX_EXPORT_ENABLED:
if settings.convert_to_onnx:
try:
onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
onnx_output_path = settings.model_path + ".onnx"
with open(onnx_output_path, "wb") as f:
f.write(onnx_graph.SerializeToString())
logger.info(f"Converting to {onnx_output_path}")
except Exception:
# Make conversion errors fatal depending on environment variables (only done during CI)
if _enforce_onnx_conversion():
raise
logger.exception(
"Exception trying to save ONNX graph. Please report this error on "
"https://github.com/Unity-Technologies/ml-agents/issues and "
"attach a copy of frozen_graph_def.pb"
)
else:
if _enforce_onnx_conversion():
raise RuntimeError(
"ONNX conversion enforced, but couldn't import dependencies."
)

for n in nodes:
logger.info("\t" + n)
return nodes
def _enforce_onnx_conversion() -> bool:
env_var_name = "TEST_ENFORCE_ONNX_CONVERSION"
if env_var_name not in os.environ:
return False
val = os.environ[env_var_name]
try:
# This handles e.g. "false" converting reasonably to False
return strtobool(val)
except Exception:
return False

3
test_constraints_max_tf1_version.txt


# For projects with upper bounds, we should periodically update this list to the latest release version
grpcio>=1.23.0
numpy>=1.17.2
# Temporary workaround for https://github.com/tensorflow/tensorflow/issues/36179 and https://github.com/tensorflow/tensorflow/issues/36188
tensorflow>=1.14.0,<1.15.1
tensorflow>=1.15.2,<2.0.0
h5py>=2.10.0

2
test_constraints_min_version.txt


numpy==1.14.1
Pillow==4.2.1
protobuf==3.6
tensorflow==1.7
tensorflow==1.7.0
h5py==2.9.0

4
test_requirements.txt


pytest-cov==2.6.1
pytest-xdist
# Tests install onnx and tf2onnx, but this doesn't support tensorflow>=2.0.0
# Since we test tensorflow2.0 with python3.7, exclude it based on the python version
tf2onnx>=1.5.5; python_version < '3.7'
tf2onnx>=1.5.5
正在加载...
取消
保存