浏览代码

Make TF graph seed deterministic

/develop/nopreviousactions
Ervin Teng 5 年前
当前提交
63463bd1
共有 2 个文件被更改,包括 2 次插入1 次删除
  1. 1
      ml-agents/mlagents/trainers/common/nn_policy.py
  2. 2
      ml-agents/mlagents/trainers/tf_policy.py

1
ml-agents/mlagents/trainers/common/nn_policy.py


Builds the tensorflow graph needed for this policy.
"""
with self.graph.as_default():
tf.set_random_seed(self.seed)
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if len(_vars) > 0:
# We assume the first thing created in the graph is the Policy. If

2
ml-agents/mlagents/trainers/tf_policy.py


self.sess = tf.Session(
config=tf_utils.generate_session_config(), graph=self.graph
)
tf.set_random_seed(seed)
self.seed = seed
if self.use_recurrent:
self.m_size = trainer_parameters["memory_size"]
self.sequence_length = trainer_parameters["sequence_length"]

正在加载...
取消
保存