
Add Multi-GPU implementation for PPO (#2288)

Add MultiGpuPPOPolicy class and command line options to run multi-GPU training
/develop-generalizationTraining-TrainerController
GitHub · 5 years ago
Current commit
a9fe719c
13 files changed, with 379 insertions and 51 deletions
  1. ml-agents/mlagents/trainers/learn.py (3 changes)
  2. ml-agents/mlagents/trainers/models.py (4 changes)
  3. ml-agents/mlagents/trainers/ppo/models.py (5 changes)
  4. ml-agents/mlagents/trainers/ppo/policy.py (103 changes)
  5. ml-agents/mlagents/trainers/ppo/trainer.py (20 changes)
  6. ml-agents/mlagents/trainers/tests/test_learn.py (2 changes)
  7. ml-agents/mlagents/trainers/tests/test_ppo.py (2 changes)
  8. ml-agents/mlagents/trainers/tests/test_simple_rl.py (1 change)
  9. ml-agents/mlagents/trainers/tests/test_trainer_controller.py (13 changes)
  10. ml-agents/mlagents/trainers/tf_policy.py (5 changes)
  11. ml-agents/mlagents/trainers/trainer_controller.py (3 changes)
  12. ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py (140 changes)
  13. ml-agents/mlagents/trainers/tests/test_multigpu.py (129 changes)

ml-agents/mlagents/trainers/learn.py (3 changes)


lesson = int(run_options["--lesson"])
fast_simulation = not bool(run_options["--slow"])
no_graphics = run_options["--no-graphics"]
multi_gpu = run_options["--multi-gpu"]
trainer_config_path = run_options["<trainer-config-path>"]
sampler_file_path = (
run_options["--sampler"] if run_options["--sampler"] != "None" else None

lesson,
run_seed,
fast_simulation,
multi_gpu,
sampler_manager,
resampling_interval,
)

--docker-target-name=<dt> Docker volume to store training-specific files [default: None].
--no-graphics Whether to run the environment in no-graphics mode [default: False].
--debug Whether to run ML-Agents in debug mode with detailed logging [default: False].
--multi-gpu Whether to use multiple GPU training [default: False].
"""
options = docopt(_USAGE)
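The new --multi-gpu flag is a plain docopt boolean, so it parses to True only when passed on the command line (for example, something along the lines of `mlagents-learn trainer_config.yaml --train --multi-gpu`, assuming the usual entry point). A minimal sketch of that parsing behaviour, using a toy usage string rather than the real _USAGE from learn.py:

# Sketch only: a cut-down usage string, not the actual _USAGE.
from docopt import docopt

_TOY_USAGE = """
Usage:
  learn [options]

Options:
  --multi-gpu    Whether to use multiple GPU training [default: False].
"""

options = docopt(_TOY_USAGE, argv=["--multi-gpu"])
print(options["--multi-gpu"])  # True; without the flag it would be False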

ml-agents/mlagents/trainers/models.py (4 changes)


hidden_policy,
self.act_size[0],
activation=None,
name="mu",
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01),
)

self.memory_out = tf.identity(memory_out, name="recurrent_out")
policy_branches = []
-        for size in self.act_size:
+        for i, size in enumerate(self.act_size):
policy_branches.append(
tf.layers.dense(
hidden,

name="policy_branch_" + str(i),
kernel_initializer=c_layers.variance_scaling_initializer(
factor=0.01
),

ml-agents/mlagents/trainers/ppo/models.py (5 changes)


)
def create_ppo_optimizer(self):
-        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
-        self.update_batch = optimizer.minimize(self.loss)
+        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
+        self.grads = self.optimizer.compute_gradients(self.loss)
+        self.update_batch = self.optimizer.minimize(self.loss)
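The change above keeps the optimizer and its raw gradients (compute_gradients) on the model rather than only calling minimize, which is what lets the multi-GPU policy below collect per-tower grads, average them, and apply them once. A minimal TF1 sketch of that split (illustrative only, not ML-Agents code):

import tensorflow as tf

x = tf.Variable(3.0)
loss = tf.square(x)

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
grads_and_vars = optimizer.compute_gradients(loss)        # list of (gradient, variable) pairs
update_batch = optimizer.apply_gradients(grads_and_vars)  # same effect as optimizer.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_batch)                                 # one Adam step on x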

ml-agents/mlagents/trainers/ppo/policy.py (103 changes)


reward_signal_configs = trainer_params["reward_signals"]
+            self.create_model(brain, trainer_params, reward_signal_configs, seed)
-            self.model = PPOModel(
-                brain,
-                lr=float(trainer_params["learning_rate"]),
-                h_size=int(trainer_params["hidden_units"]),
-                epsilon=float(trainer_params["epsilon"]),
-                beta=float(trainer_params["beta"]),
-                max_step=float(trainer_params["max_steps"]),
-                normalize=trainer_params["normalize"],
-                use_recurrent=trainer_params["use_recurrent"],
-                num_layers=int(trainer_params["num_layers"]),
-                m_size=self.m_size,
-                seed=seed,
-                stream_names=list(reward_signal_configs.keys()),
-                vis_encode_type=EncoderType(
-                    trainer_params.get("vis_encode_type", "simple")
-                ),
-            )
-            self.model.create_ppo_optimizer()
# Create reward signals
for reward_signal, config in reward_signal_configs.items():
self.reward_signals[reward_signal] = create_reward_signal(

"update_batch": self.model.update_batch,
}
    def create_model(self, brain, trainer_params, reward_signal_configs, seed):
        """
        Create PPO model
        :param brain: Assigned Brain object.
        :param trainer_params: Defined training parameters.
        :param reward_signal_configs: Reward signal config
        :param seed: Random seed.
        """
        with self.graph.as_default():
            self.model = PPOModel(
                brain=brain,
                lr=float(trainer_params["learning_rate"]),
                h_size=int(trainer_params["hidden_units"]),
                epsilon=float(trainer_params["epsilon"]),
                beta=float(trainer_params["beta"]),
                max_step=float(trainer_params["max_steps"]),
                normalize=trainer_params["normalize"],
                use_recurrent=trainer_params["use_recurrent"],
                num_layers=int(trainer_params["num_layers"]),
                m_size=self.m_size,
                seed=seed,
                stream_names=list(reward_signal_configs.keys()),
                vis_encode_type=EncoderType(
                    trainer_params.get("vis_encode_type", "simple")
                ),
            )
            self.model.create_ppo_optimizer()
@timed
def evaluate(self, brain_info):
"""

:param mini_batch: Experience batch.
:return: Output from update process.
"""
+        feed_dict = self.construct_feed_dict(self.model, mini_batch, num_sequences)
+        run_out = self._execute_model(feed_dict, self.update_dict)
+        return run_out
+    def construct_feed_dict(self, model, mini_batch, num_sequences):
-            self.model.batch_size: num_sequences,
-            self.model.sequence_length: self.sequence_length,
-            self.model.mask_input: mini_batch["masks"].flatten(),
-            self.model.advantage: mini_batch["advantages"].reshape([-1, 1]),
-            self.model.all_old_log_probs: mini_batch["action_probs"].reshape(
-                [-1, sum(self.model.act_size)]
+            model.batch_size: num_sequences,
+            model.sequence_length: self.sequence_length,
+            model.mask_input: mini_batch["masks"].flatten(),
+            model.advantage: mini_batch["advantages"].reshape([-1, 1]),
+            model.all_old_log_probs: mini_batch["action_probs"].reshape(
+                [-1, sum(model.act_size)]
-            feed_dict[self.model.returns_holders[name]] = mini_batch[
+            feed_dict[model.returns_holders[name]] = mini_batch[
-            feed_dict[self.model.old_values[name]] = mini_batch[
+            feed_dict[model.old_values[name]] = mini_batch[
-            feed_dict[self.model.output_pre] = mini_batch["actions_pre"].reshape(
-                [-1, self.model.act_size[0]]
+            feed_dict[model.output_pre] = mini_batch["actions_pre"].reshape(
+                [-1, model.act_size[0]]
-            feed_dict[self.model.epsilon] = mini_batch["random_normal_epsilon"].reshape(
-                [-1, self.model.act_size[0]]
+            feed_dict[model.epsilon] = mini_batch["random_normal_epsilon"].reshape(
+                [-1, model.act_size[0]]
-            feed_dict[self.model.action_holder] = mini_batch["actions"].reshape(
-                [-1, len(self.model.act_size)]
+            feed_dict[model.action_holder] = mini_batch["actions"].reshape(
+                [-1, len(model.act_size)]
-                feed_dict[self.model.prev_action] = mini_batch["prev_action"].reshape(
-                    [-1, len(self.model.act_size)]
+                feed_dict[model.prev_action] = mini_batch["prev_action"].reshape(
+                    [-1, len(model.act_size)]
-            feed_dict[self.model.action_masks] = mini_batch["action_mask"].reshape(
+            feed_dict[model.action_masks] = mini_batch["action_mask"].reshape(
-            feed_dict[self.model.vector_in] = mini_batch["vector_obs"].reshape(
+            feed_dict[model.vector_in] = mini_batch["vector_obs"].reshape(
-        if self.model.vis_obs_size > 0:
-            for i, _ in enumerate(self.model.visual_in):
+        if model.vis_obs_size > 0:
+            for i, _ in enumerate(model.visual_in):
-                    feed_dict[self.model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
+                    feed_dict[model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
-                    feed_dict[self.model.visual_in[i]] = _obs
+                    feed_dict[model.visual_in[i]] = _obs
-            feed_dict[self.model.memory_in] = mem_in
-        run_out = self._execute_model(feed_dict, self.update_dict)
-        return run_out
+            feed_dict[model.memory_in] = mem_in
+        return feed_dict
def get_value_estimates(
self, brain_info: BrainInfo, idx: int, done: bool

ml-agents/mlagents/trainers/ppo/trainer.py (20 changes)


from mlagents.envs import AllBrainInfo, BrainInfo
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy, get_devices
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.envs.action_info import ActionInfoOutputs

"""The PPOTrainer is an implementation of the PPO algorithm."""
def __init__(
-        self, brain, reward_buff_cap, trainer_parameters, training, load, seed, run_id
+        self,
+        brain,
+        reward_buff_cap,
+        trainer_parameters,
+        training,
+        load,
+        seed,
+        run_id,
+        multi_gpu,
):
"""
Responsible for collecting experiences and training PPO model.

)
self.step = 0
-        self.policy = PPOPolicy(seed, brain, trainer_parameters, self.is_training, load)
+        if multi_gpu and len(get_devices()) > 1:
+            self.policy = MultiGpuPPOPolicy(
+                seed, brain, trainer_parameters, self.is_training, load
+            )
+        else:
+            self.policy = PPOPolicy(
+                seed, brain, trainer_parameters, self.is_training, load
+            )
stats = defaultdict(list)
# collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
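In other words, the trainer only switches to the multi-GPU policy when the flag is set and more than one GPU is actually visible; with zero or one GPU it quietly falls back to the regular PPOPolicy. A small sketch of that selection logic in isolation (the imports are the ones the trainer above already uses; choose_policy_class is a hypothetical helper, not part of the PR):

from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy, get_devices

def choose_policy_class(multi_gpu: bool):
    # Multi-GPU only applies when more than one GPU is visible.
    if multi_gpu and len(get_devices()) > 1:
        return MultiGpuPPOPolicy
    return PPOPolicy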

ml-agents/mlagents/trainers/tests/test_learn.py (2 changes)


"--no-graphics": False,
"<trainer-config-path>": "basic_path",
"--debug": False,
"--multi-gpu": False,
"--sampler": None,
}

0,
0,
True,
False,
sampler_manager_mock.return_value,
None,
)

ml-agents/mlagents/trainers/tests/test_ppo.py (2 changes)


}
brain_params = BrainParameters("test_brain", 1, 1, [], [2], [], 0)
-    trainer = PPOTrainer(brain_params, 0, trainer_params, True, False, 0, "0")
+    trainer = PPOTrainer(brain_params, 0, trainer_params, True, False, 0, "0", False)
policy_mock = mock.Mock()
step_count = 10
policy_mock.increment_step = mock.Mock(return_value=step_count)

ml-agents/mlagents/trainers/tests/test_simple_rl.py (1 change)


lesson=None,
training_seed=1337,
fast_simulation=True,
multi_gpu=False,
sampler_manager=SamplerManager(None),
resampling_interval=None,
)

ml-agents/mlagents/trainers/tests/test_trainer_controller.py (13 changes)


lesson=None,
training_seed=99,
fast_simulation=True,
multi_gpu=False,
sampler_manager=SamplerManager(None),
resampling_interval=None,
)

None,
seed,
True,
False,
SamplerManager(None),
None,
)

external_brains = {"testbrain": expected_brain_params}
def mock_constructor(
-        self, brain, reward_buff_cap, trainer_parameters, training, load, seed, run_id
+        self,
+        brain,
+        reward_buff_cap,
+        trainer_parameters,
+        training,
+        load,
+        seed,
+        run_id,
+        multi_gpu,
):
self.trainer_metrics = TrainerMetrics("", "")
assert brain == expected_brain_params

assert load == tc.load_model
assert seed == tc.seed
assert run_id == tc.run_id
assert multi_gpu == tc.multi_gpu
with patch.object(PPOTrainer, "__init__", mock_constructor):
tc.initialize_trainers(input_config, external_brains)

ml-agents/mlagents/trainers/tf_policy.py (5 changes)


self.graph = tf.Graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# For multi-GPU training, set allow_soft_placement to True so that an
# operation is automatically placed on an alternative device, preventing
# exceptions when the requested device doesn't support the operation
# or doesn't exist
config.allow_soft_placement = True
self.sess = tf.Session(config=config, graph=self.graph)
self.saver = None
if self.use_recurrent:
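As a rough illustration of why allow_soft_placement matters here: without it, a graph that pins ops to a GPU that is missing on the current machine fails at session time; with it, TensorFlow silently falls back to a device that does exist. A standalone TF1 sketch (the device string is deliberately one that probably doesn't exist):

import tensorflow as tf

config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True

with tf.device("/device:GPU:7"):        # a device that may not be present
    y = tf.constant(1.0) + tf.constant(2.0)

with tf.Session(config=config) as sess:
    print(sess.run(y))                  # placed on an available device instead of failing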

ml-agents/mlagents/trainers/trainer_controller.py (3 changes)


lesson: Optional[int],
training_seed: int,
fast_simulation: bool,
multi_gpu: bool,
sampler_manager: SamplerManager,
resampling_interval: Optional[int],
):

self.seed = training_seed
self.training_start_time = time()
self.fast_simulation = fast_simulation
self.multi_gpu = multi_gpu
np.random.seed(self.seed)
tf.set_random_seed(self.seed)
self.sampler_manager = sampler_manager

load=self.load_model,
seed=self.seed,
run_id=self.run_id,
multi_gpu=self.multi_gpu,
)
self.trainer_metrics[brain_name] = self.trainers[
brain_name

ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py (140 changes)


import logging

import numpy as np
import tensorflow as tf
from tensorflow.python.client import device_lib

from mlagents.envs.timers import timed
from mlagents.trainers.models import EncoderType
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
    create_reward_signal,
)
from mlagents.trainers.components.bc.module import BCModule

# Variable scope in which created variables will be placed under
TOWER_SCOPE_NAME = "tower"

logger = logging.getLogger("mlagents.trainers")


class MultiGpuPPOPolicy(PPOPolicy):
    def __init__(self, seed, brain, trainer_params, is_training, load):
        """
        Policy for Proximal Policy Optimization Networks with multi-GPU training
        :param seed: Random seed.
        :param brain: Assigned Brain object.
        :param trainer_params: Defined training parameters.
        :param is_training: Whether the model should be trained.
        :param load: Whether a pre-trained model will be loaded or a new one created.
        """
        super().__init__(seed, brain, trainer_params, is_training, load)

        with self.graph.as_default():
            avg_grads = self.average_gradients([t.grads for t in self.towers])
            self.update_batch = self.model.optimizer.apply_gradients(avg_grads)

        self.update_dict = {"update_batch": self.update_batch}
        self.update_dict.update(
            {
                "value_loss_" + str(i): self.towers[i].value_loss
                for i in range(len(self.towers))
            }
        )
        self.update_dict.update(
            {
                "policy_loss_" + str(i): self.towers[i].policy_loss
                for i in range(len(self.towers))
            }
        )

    def create_model(self, brain, trainer_params, reward_signal_configs, seed):
        """
        Create PPO models, one on each device
        :param brain: Assigned Brain object.
        :param trainer_params: Defined training parameters.
        :param reward_signal_configs: Reward signal config
        :param seed: Random seed.
        """
        self.devices = get_devices()
        self.towers = []
        with self.graph.as_default():
            with tf.variable_scope(TOWER_SCOPE_NAME, reuse=tf.AUTO_REUSE):
                for device in self.devices:
                    with tf.device(device):
                        self.towers.append(
                            PPOModel(
                                brain=brain,
                                lr=float(trainer_params["learning_rate"]),
                                h_size=int(trainer_params["hidden_units"]),
                                epsilon=float(trainer_params["epsilon"]),
                                beta=float(trainer_params["beta"]),
                                max_step=float(trainer_params["max_steps"]),
                                normalize=trainer_params["normalize"],
                                use_recurrent=trainer_params["use_recurrent"],
                                num_layers=int(trainer_params["num_layers"]),
                                m_size=self.m_size,
                                seed=seed,
                                stream_names=list(reward_signal_configs.keys()),
                                vis_encode_type=EncoderType(
                                    trainer_params.get("vis_encode_type", "simple")
                                ),
                            )
                        )
                        self.towers[-1].create_ppo_optimizer()
        self.model = self.towers[0]

    @timed
    def update(self, mini_batch, num_sequences):
        """
        Updates model using buffer.
        :param n_sequences: Number of trajectories in batch.
        :param mini_batch: Experience batch.
        :return: Output from update process.
        """
        feed_dict = {}

        device_batch_size = num_sequences // len(self.devices)
        device_batches = []
        for i in range(len(self.devices)):
            device_batches.append(
                {k: v[i : i + device_batch_size] for (k, v) in mini_batch.items()}
            )

        for batch, tower in zip(device_batches, self.towers):
            feed_dict.update(self.construct_feed_dict(tower, batch, num_sequences))

        out = self._execute_model(feed_dict, self.update_dict)
        run_out = {}
        run_out["value_loss"] = np.mean(
            [out["value_loss_" + str(i)] for i in range(len(self.towers))]
        )
        run_out["policy_loss"] = np.mean(
            [out["policy_loss_" + str(i)] for i in range(len(self.towers))]
        )
        run_out["update_batch"] = out["update_batch"]
        return run_out

    def average_gradients(self, tower_grads):
        """
        Average gradients from all towers
        :param tower_grads: Gradients from all towers
        """
        average_grads = []
        for grad_and_vars in zip(*tower_grads):
            grads = [g for g, _ in grad_and_vars if g is not None]
            if not grads:
                continue
            avg_grad = tf.reduce_mean(tf.stack(grads), 0)
            var = grad_and_vars[0][1]
            average_grads.append((avg_grad, var))
        return average_grads


def get_devices():
    """
    Get all available GPU devices
    """
    local_device_protos = device_lib.list_local_devices()
    devices = [x.name for x in local_device_protos if x.device_type == "GPU"]
    return devices
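For intuition, average_gradients does nothing more exotic than a per-variable mean over the towers; the NumPy equivalent of the TF ops it builds (tf.reduce_mean(tf.stack(...), 0)) is simply:

import numpy as np

# One gradient per tower for the same variable, as in test_average_gradients below.
tower_grads_for_one_var = [0.1, 0.2, 0.3, 0.4]
avg = np.mean(np.stack(tower_grads_for_one_var), axis=0)
print(avg)  # ~0.25, matching the expected value in the test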

ml-agents/mlagents/trainers/tests/test_multigpu.py (129 changes)


import unittest.mock as mock
import pytest

import numpy as np
import tensorflow as tf
import yaml

from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy, get_devices
from mlagents.envs import UnityEnvironment, BrainParameters
from mlagents.envs.mock_communicator import MockCommunicator
from mlagents.trainers.tests.mock_brain import create_mock_brainparams


@pytest.fixture
def dummy_config():
    return yaml.safe_load(
        """
        trainer: ppo
        batch_size: 32
        beta: 5.0e-3
        buffer_size: 512
        epsilon: 0.2
        hidden_units: 128
        lambd: 0.95
        learning_rate: 3.0e-4
        max_steps: 5.0e4
        normalize: true
        num_epoch: 5
        num_layers: 2
        time_horizon: 64
        sequence_length: 64
        summary_freq: 1000
        use_recurrent: false
        memory_size: 8
        curiosity_strength: 0.0
        curiosity_enc_size: 1
        reward_signals:
            extrinsic:
                strength: 1.0
                gamma: 0.99
        """
    )


@mock.patch("mlagents.trainers.ppo.multi_gpu_policy.get_devices")
def test_create_model(mock_get_devices, dummy_config):
    tf.reset_default_graph()
    mock_get_devices.return_value = [
        "/device:GPU:0",
        "/device:GPU:1",
        "/device:GPU:2",
        "/device:GPU:3",
    ]

    trainer_parameters = dummy_config
    trainer_parameters["model_path"] = ""
    trainer_parameters["keep_checkpoints"] = 3
    brain = create_mock_brainparams()

    policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False)
    assert len(policy.towers) == len(mock_get_devices.return_value)


@mock.patch("mlagents.trainers.ppo.multi_gpu_policy.get_devices")
def test_average_gradients(mock_get_devices, dummy_config):
    tf.reset_default_graph()
    mock_get_devices.return_value = [
        "/device:GPU:0",
        "/device:GPU:1",
        "/device:GPU:2",
        "/device:GPU:3",
    ]

    trainer_parameters = dummy_config
    trainer_parameters["model_path"] = ""
    trainer_parameters["keep_checkpoints"] = 3
    brain = create_mock_brainparams()

    with tf.Session() as sess:
        policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False)
        var = tf.Variable(0)
        tower_grads = [
            [(tf.constant(0.1), var)],
            [(tf.constant(0.2), var)],
            [(tf.constant(0.3), var)],
            [(tf.constant(0.4), var)],
        ]
        avg_grads = policy.average_gradients(tower_grads)

        init = tf.global_variables_initializer()
        sess.run(init)
        run_out = sess.run(avg_grads)
        assert run_out == [(0.25, 0)]


@mock.patch("mlagents.trainers.tf_policy.TFPolicy._execute_model")
@mock.patch("mlagents.trainers.ppo.policy.PPOPolicy.construct_feed_dict")
@mock.patch("mlagents.trainers.ppo.multi_gpu_policy.get_devices")
def test_update(
    mock_get_devices, mock_construct_feed_dict, mock_execute_model, dummy_config
):
    tf.reset_default_graph()
    mock_get_devices.return_value = ["/device:GPU:0", "/device:GPU:1"]
    mock_construct_feed_dict.return_value = {}
    mock_execute_model.return_value = {
        "value_loss_0": 0.1,
        "value_loss_1": 0.3,
        "policy_loss_0": 0.5,
        "policy_loss_1": 0.7,
        "update_batch": None,
    }

    trainer_parameters = dummy_config
    trainer_parameters["model_path"] = ""
    trainer_parameters["keep_checkpoints"] = 3
    brain = create_mock_brainparams()

    policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False)
    mock_mini_batch = mock.Mock()
    mock_mini_batch.items.return_value = [("action", [1, 2]), ("value", [3, 4])]
    run_out = policy.update(mock_mini_batch, 1)

    assert mock_mini_batch.items.call_count == len(mock_get_devices.return_value)
    assert mock_construct_feed_dict.call_count == len(mock_get_devices.return_value)
    assert run_out["value_loss"] == 0.2
    assert run_out["policy_loss"] == 0.6


if __name__ == "__main__":
    pytest.main()