浏览代码

move compat functions

/develop-gpu-test
Chris Elion 5 年前
当前提交
8da16bdb
共有 5 个文件被更改,包括 17 次插入34 次删除
  1. 3
      ml-agents/mlagents/trainers/__init__.py
  2. 11
      ml-agents/mlagents/trainers/bc/models.py
  3. 13
      ml-agents/mlagents/trainers/models.py
  4. 11
      ml-agents/mlagents/trainers/sac/models.py
  5. 13
      ml-agents/mlagents/trainers/tf.py

3
ml-agents/mlagents/trainers/__init__.py


from mlagents.trainers.tf import tf as tf
from mlagents.trainers.tf import tf as tf # noqa
from mlagents.trainers.tf import tf_flatten, tf_rnn, tf_variance_scaling # noqa

11
ml-agents/mlagents/trainers/bc/models.py


from mlagents.trainers import tf
if True: # TODO TF2
tf_variance_scaling = tf.initializers.variance_scaling
tf_flatten = tf.layers.flatten
else:
import tensorflow.contrib.layers as c_layers
tf_variance_scaling = c_layers.variance_scaling_initializer
tf_flatten = c_layers.flatten
from mlagents.trainers import tf, tf_variance_scaling
from mlagents.trainers.models import LearningModel

13
ml-agents/mlagents/trainers/models.py


from typing import Callable, List
import numpy as np
from mlagents.trainers import tf
if True: # TODO TF2
tf_variance_scaling = tf.initializers.variance_scaling
tf_flatten = tf.layers.flatten
tf_rnn = tf.nn.rnn_cell
else:
import tensorflow.contrib.layers as c_layers
tf_variance_scaling = c_layers.variance_scaling_initializer
tf_flatten = c_layers.flatten
tf_rnn = tf.contrib.rnn
from mlagents.trainers import tf, tf_variance_scaling, tf_rnn, tf_flatten
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.envs.brain import CameraResolution

11
ml-agents/mlagents/trainers/sac/models.py


import logging
import numpy as np
from mlagents.trainers import tf
from mlagents.trainers import tf, tf_variance_scaling
if True: # TODO TF2
tf_variance_scaling = tf.initializers.variance_scaling
tf_flatten = tf.layers.flatten
else:
import tensorflow.contrib.layers as c_layers
tf_variance_scaling = c_layers.variance_scaling_initializer
tf_flatten = c_layers.flatten
LOG_STD_MAX = 2
LOG_STD_MIN = -20

13
ml-agents/mlagents/trainers/tf.py


except ImportError:
import tensorflow as tf
is_tf2 = True
# TODO better version check, this will do for now though
is_tf2 = tf.__version__ == "2.0.0"
tf_variance_scaling = tf.initializers.variance_scaling
tf_flatten = tf.layers.flatten
tf_rnn = tf.nn.rnn_cell
else:
import tensorflow.contrib.layers as c_layers
tf_variance_scaling = c_layers.variance_scaling_initializer
tf_flatten = c_layers.flatten
tf_rnn = tf.contrib.rnn
正在加载...
取消
保存