
Prevent tf.Session() from eating up all the GPU memory (#3219)

* Use soft placement and allow_growth for Session

* Move config generation to tf utils

* Re-add self.graph
Branch: release-0.13.1
Committed by GitHub 5 years ago
Current commit: d798b1cb
4 changed files with 23 additions and 10 deletions.
  1. ml-agents/mlagents/tf_utils/__init__.py (1 change)
  2. ml-agents/mlagents/tf_utils/tf.py (17 changes)
  3. ml-agents/mlagents/trainers/tf_policy.py (12 changes)
  4. ml-agents/mlagents/trainers/trainer.py (3 changes)

ml-agents/mlagents/tf_utils/__init__.py (1 change)


  from mlagents.tf_utils.tf import tf as tf # noqa
  from mlagents.tf_utils.tf import set_warnings_enabled # noqa
+ from mlagents.tf_utils.tf import generate_session_config # noqa

ml-agents/mlagents/tf_utils/tf.py (17 changes)


  def set_warnings_enabled(is_enabled: bool) -> None:
      """
-     Enable or disable tensorflow warnings (notabley, this disables deprecation warnings.
+     Enable or disable tensorflow warnings (notably, this disables deprecation warnings).

+ def generate_session_config() -> tf.ConfigProto:
+     """
+     Generate a ConfigProto to use for ML-Agents that doesn't consume all of the GPU memory
+     and allows for soft placement in the case of multi-GPU.
+     """
+     config = tf.ConfigProto()
+     config.gpu_options.allow_growth = True
+     # For multi-GPU training, set allow_soft_placement to True to allow
+     # placing the operation on an alternative device automatically
+     # to prevent exceptions if the device doesn't support the operation
+     # or the device does not exist.
+     config.allow_soft_placement = True
+     return config
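For reference, a minimal sketch of how the new helper is meant to be consumed (assuming TensorFlow 1.x and the re-export added to ml-agents/mlagents/tf_utils/__init__.py above; the string constant is only a placeholder op):

    from mlagents.tf_utils import tf, generate_session_config

    # allow_growth makes TensorFlow claim GPU memory on demand instead of
    # reserving the whole device up front; allow_soft_placement lets an op
    # fall back to another device if its pinned device can't run it.
    with tf.Session(config=generate_session_config()) as sess:
        print(sess.run(tf.constant("session configured")))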

ml-agents/mlagents/trainers/tf_policy.py (12 changes)


  import numpy as np
  from mlagents.tf_utils import tf
+ from mlagents import tf_utils
  from mlagents_envs.exception import UnityException
  from mlagents.trainers.policy import Policy

          self.model_path = trainer_parameters["model_path"]
          self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
          self.graph = tf.Graph()
-         config = tf.ConfigProto()
-         config.gpu_options.allow_growth = True
-         # For multi-GPU training, set allow_soft_placement to True to allow
-         # placing the operation into an alternative device automatically
-         # to prevent from exceptions if the device doesn't suppport the operation
-         # or the device does not exist
-         config.allow_soft_placement = True
-         self.sess = tf.Session(config=config, graph=self.graph)
+         self.sess = tf.Session(
+             config=tf_utils.generate_session_config(), graph=self.graph
+         )
          self.saver = None
          if self.use_recurrent:
              self.m_size = trainer_parameters["memory_size"]

ml-agents/mlagents/trainers/trainer.py (3 changes)


  from typing import Dict, List, Deque, Any
  from mlagents.tf_utils import tf
+ from mlagents import tf_utils
  from collections import deque

          :param input_dict: A dictionary that will be displayed in a table on Tensorboard.
          """
          try:
-             with tf.Session() as sess:
+             with tf.Session(config=tf_utils.generate_session_config()) as sess:
                  s_op = tf.summary.text(
                      key,
                      tf.convert_to_tensor(