浏览代码

version check

/develop-gpu-test
Chris Elion 5 年前
当前提交
73acf8a7
共有 1 个文件被更改,包括 8 次插入3 次删除
  1. 11
      ml-agents/mlagents/trainers/tf.py

11
ml-agents/mlagents/trainers/tf.py


import tensorflow as tf
from distutils.version import LooseVersion
# TODO better version check, this will do for now though
is_tf2 = tf.__version__ == "2.0.0"
# LooseVersion handles things "1.2.3a" or "4.5.6-rc7" fairly sensibly.
_is_tensorflow2 = LooseVersion(tf.__version__) >= LooseVersion("2.0.0")
# A few things that we use live in different places between tensorflow 1.x and 2.x
# If anything new is added, please add it here
if is_tf2:
if _is_tensorflow2:
import tensorflow.compat.v1 as tf
tf_variance_scaling = tf.initializers.variance_scaling

正在加载...
取消
保存