Builds the tensorflow graph needed for this policy.
"""
with self.graph.as_default():
tf.set_random_seed(self.seed)
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if len(_vars) > 0:
# We assume the first thing created in the graph is the Policy. If
self.sess = tf.Session(
config=tf_utils.generate_session_config(), graph=self.graph
)
tf.set_random_seed(seed)
self.seed = seed
if self.use_recurrent:
self.m_size = trainer_parameters["memory_size"]
self.sequence_length = trainer_parameters["sequence_length"]