浏览代码

Fix issue exporting graph with multi-GPU (#2573)

Our multi-GPU training had a regression that broke freezing the
graph. This change fixes the issue with the following
changes:

* Removes the top-level "tower" variable scope added by the multi-GPU
  policy so that the output nodes have the correct names
* Removes the use of "freeze_graph" and replaces it with our own equivalent
  logic built on "graph_util.convert_variables_to_constants".
* Adds "reuse=tf.AUTO_REUSE" to the network layers that require it
/develop-gpu-test
GitHub 5 年前
当前提交
36ed3c16
共有 3 个文件被更改，包括 12 次插入和 18 次删除
  1. 1
      ml-agents/mlagents/trainers/ppo/models.py
  2. 3
      ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
  3. 26
      ml-agents/mlagents/trainers/tf_policy.py

1
ml-agents/mlagents/trainers/ppo/models.py


self.act_size[0],
activation=None,
kernel_initializer=LearningModel.scaled_init(0.01),
reuse=tf.AUTO_REUSE,
)
self.log_sigma_sq = tf.get_variable(

3
ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py


self.devices = get_devices()
self.towers = []
with self.graph.as_default():
with tf.variable_scope(TOWER_SCOPE_NAME, reuse=tf.AUTO_REUSE):
with tf.variable_scope("", reuse=tf.AUTO_REUSE):
for device in self.devices:
with tf.device(device):
self.towers.append(

)
self.towers[-1].create_ppo_optimizer()
self.model = self.towers[0]
avg_grads = self.average_gradients([t.grads for t in self.towers])
update_batch = self.model.optimizer.apply_gradients(avg_grads)

26
ml-agents/mlagents/trainers/tf_policy.py


from mlagents.envs.exception import UnityException
from mlagents.envs.policy import Policy
from mlagents.envs.action_info import ActionInfo
from tensorflow.python.tools import freeze_graph
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
from mlagents.trainers import tensorflow_to_barracuda as tf2bc
from mlagents.envs.brain import BrainInfo

with self.graph.as_default():
target_nodes = ",".join(self._process_graph())
ckpt = tf.train.get_checkpoint_state(self.model_path)
freeze_graph.freeze_graph(
input_graph=self.model_path + "/raw_graph_def.pb",
input_binary=True,
input_checkpoint=ckpt.model_checkpoint_path,
output_node_names=target_nodes,
output_graph=(self.model_path + "/frozen_graph_def.pb"),
clear_devices=True,
initializer_nodes="",
input_saver="",
restore_op_name="save/restore_all",
filename_tensor_name="save/Const:0",
graph_def = self.graph.as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
self.sess, graph_def, target_nodes.replace(" ", "").split(",")
tf2bc.convert(self.model_path + "/frozen_graph_def.pb", self.model_path + ".nn")
logger.info("Exported " + self.model_path + ".nn file")
frozen_graph_def_path = self.model_path + "/frozen_graph_def.pb"
with gfile.GFile(frozen_graph_def_path, "wb") as f:
f.write(output_graph_def.SerializeToString())
tf2bc.convert(frozen_graph_def_path, self.model_path + ".nn")
logger.info("Exported " + self.model_path + ".nn file")
def _process_graph(self):
"""

正在加载...
取消
保存