|
|
|
|
|
|
import tensorflow as tf |
|
|
|
from distutils.version import LooseVersion |
|
|
|
# TODO better version check, this will do for now though |
|
|
|
is_tf2 = tf.__version__ == "2.0.0" |
|
|
|
|
|
|
|
# LooseVersion handles things "1.2.3a" or "4.5.6-rc7" fairly sensibly. |
|
|
|
_is_tensorflow2 = LooseVersion(tf.__version__) >= LooseVersion("2.0.0") |
|
|
|
|
|
|
|
# A few things that we use live in different places between tensorflow 1.x and 2.x |
|
|
|
# If anything new is added, please add it here |
|
|
|
if is_tf2: |
|
|
|
if _is_tensorflow2: |
|
|
|
import tensorflow.compat.v1 as tf |
|
|
|
|
|
|
|
tf_variance_scaling = tf.initializers.variance_scaling |
|
|
|