浏览代码

Fix graph init in ghost trainer

/develop/nopreviousactions
Ervin Teng 5 年前
当前提交
dcbb90e1
共有 4 个文件被更改,包括 12 次插入4 次删除
  1. 3
      ml-agents/mlagents/trainers/ghost/trainer.py
  2. 9
      ml-agents/mlagents/trainers/optimizer.py
  3. 2
      ml-agents/mlagents/trainers/ppo/optimizer.py
  4. 2
      ml-agents/mlagents/trainers/sac/optimizer.py

3
ml-agents/mlagents/trainers/ghost/trainer.py


def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
self.policies[name_behavior_id] = policy
policy.create_tf_graph()
# First policy encountered
if not self.learning_behavior_name:

self._save_snapshot(policy) # Need to save after trainer initializes policy
self.learning_behavior_name = name_behavior_id
else:
# Normally Optimizer initializes policy. Do it here instead.
policy.create_tf_graph()
# for saving/swapping snapshots
policy.init_load_weights()

9
ml-agents/mlagents/trainers/optimizer.py


def create_tf_optimizer(self, learning_rate, name="Adam"):
return tf.train.AdamOptimizer(learning_rate=learning_rate, name=name)
def _create_policy_tf_graph_if_needed(self, policy):
"""
Creates the policy TF graph. If already created, don't do anything.
"""
with policy.graph.as_default():
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if len(_vars) == 0:
policy.create_tf_graph()
def _execute_model(self, feed_dict, out_dict):
"""
Executes model.

2
ml-agents/mlagents/trainers/ppo/optimizer.py


:param trainer_params: Trainer parameters dictionary that specifies the properties of the trainer.
"""
# Create the graph here to give more granular control of the TF graph to the Optimizer.
policy.create_tf_graph()
self._create_policy_tf_graph_if_needed(policy)
with policy.graph.as_default():
with tf.variable_scope("optimizer/"):

2
ml-agents/mlagents/trainers/sac/optimizer.py


:param m_size: Size of brain memory.
"""
# Create the graph here to give more granular control of the TF graph to the Optimizer.
policy.create_tf_graph()
self._create_policy_tf_graph_if_needed(policy)
with policy.graph.as_default():
with tf.variable_scope(""):

正在加载...
取消
保存