浏览代码

fix imports for 1.14

/develop-gpu-test
Chris Elion 5 年前
当前提交
1aa07606
共有 1 个文件被更改,包括 4 次插入19 次删除
  1. 23
      ml-agents/mlagents/trainers/tf.py

23
ml-agents/mlagents/trainers/tf.py


import logging
import tensorflow as tf
def warnings_as_errors(log_record):
# Raise deprecated warnings as exceptions.
if log_record.levelno == logging.WARNING and "deprecated" in log_record.msg:
merged = log_record.getMessage()
raise RuntimeError(merged)
return True
# TODO better version check, this will do for now though
is_tf2 = tf.__version__ == "2.0.0"
# TODO only enable this with a environment variable
if False:
logging.getLogger("tensorflow").addFilter(warnings_as_errors)
try:
if is_tf2:
except ImportError:
import tensorflow as tf
# TODO better version check, this will do for now though
is_tf2 = tf.__version__ == "2.0.0"
if is_tf2:
tf_variance_scaling = tf.initializers.variance_scaling
tf_flatten = tf.layers.flatten
tf_rnn = tf.nn.rnn_cell

正在加载...
取消
保存