|
|
|
|
|
|
def create_tf_optimizer(self, learning_rate, name="Adam"):
    """
    Build the TensorFlow optimizer used for training.

    :param learning_rate: Value forwarded to the optimizer as its
        ``learning_rate`` argument.
    :param name: Value forwarded to the optimizer as its ``name`` argument.
    :return: A new ``tf.train.AdamOptimizer`` instance.
    """
    adam_optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate,
        name=name,
    )
    return adam_optimizer
|
|
|
|
|
|
|
def _create_policy_tf_graph_if_needed(self, policy):
    """
    Build the policy's TF graph unless it has already been built.

    The graph counts as already built when its GLOBAL_VARIABLES
    collection is non-empty; in that case nothing is done.

    :param policy: Object exposing a ``graph`` attribute and a
        ``create_tf_graph()`` method.
    """
    # Everything runs with the policy's graph installed as the default graph,
    # so both the collection lookup and the graph construction target it.
    with policy.graph.as_default():
        existing_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        if existing_variables:
            # Variables are already registered: the graph exists, nothing to do.
            return
        policy.create_tf_graph()
|
|
|
|
|
|
|
def _execute_model(self, feed_dict, out_dict): |
|
|
|
""" |
|
|
|
Executes model. |
|
|
|