Add Multi-GPU implementation for PPO (#2288)
Add MultiGpuPPOPolicy class and command-line options to run multi-GPU training (branch: develop-generalizationTraining-TrainerController)
GitHub
5 years ago
Current commit: a9fe719c
13 files changed, with 379 insertions and 51 deletions
- ml-agents/mlagents/trainers/learn.py (3 changed lines)
- ml-agents/mlagents/trainers/models.py (4 changed lines)
- ml-agents/mlagents/trainers/ppo/models.py (5 changed lines)
- ml-agents/mlagents/trainers/ppo/policy.py (103 changed lines)
- ml-agents/mlagents/trainers/ppo/trainer.py (20 changed lines)
- ml-agents/mlagents/trainers/tests/test_learn.py (2 changed lines)
- ml-agents/mlagents/trainers/tests/test_ppo.py (2 changed lines)
- ml-agents/mlagents/trainers/tests/test_simple_rl.py (1 changed line)
- ml-agents/mlagents/trainers/tests/test_trainer_controller.py (13 changed lines)
- ml-agents/mlagents/trainers/tf_policy.py (5 changed lines)
- ml-agents/mlagents/trainers/trainer_controller.py (3 changed lines)
- ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py (140 changed lines)
- ml-agents/mlagents/trainers/tests/test_multigpu.py (129 changed lines)
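The changes to learn.py, trainer.py, and trainer_controller.py that wire the new policy to a command-line option are summarized above but not reproduced on this page. A minimal sketch of how that selection would typically look, assuming a boolean multi_gpu option is passed through from the CLI (the option name and the helper function are illustrative, not taken verbatim from the diff):

# Hypothetical helper: pick the multi-GPU policy only when it was requested
# and more than one GPU is actually visible to TensorFlow.
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy, get_devices


def make_ppo_policy(seed, brain, trainer_params, is_training, load, multi_gpu=False):
    if multi_gpu and len(get_devices()) > 1:
        return MultiGpuPPOPolicy(seed, brain, trainer_params, is_training, load)
    return PPOPolicy(seed, brain, trainer_params, is_training, load)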
ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py:

import logging
import numpy as np

import tensorflow as tf
from tensorflow.python.client import device_lib

from mlagents.envs.timers import timed
from mlagents.trainers.models import EncoderType
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
    create_reward_signal,
)
from mlagents.trainers.components.bc.module import BCModule

# Variable scope under which created variables will be placed
TOWER_SCOPE_NAME = "tower"

logger = logging.getLogger("mlagents.trainers")


class MultiGpuPPOPolicy(PPOPolicy):
    def __init__(self, seed, brain, trainer_params, is_training, load):
        """
        Policy for Proximal Policy Optimization Networks with multi-GPU training
        :param seed: Random seed.
        :param brain: Assigned Brain object.
        :param trainer_params: Defined training parameters.
        :param is_training: Whether the model should be trained.
        :param load: Whether a pre-trained model will be loaded or a new one created.
        """
        super().__init__(seed, brain, trainer_params, is_training, load)

        with self.graph.as_default():
            # Apply the cross-tower averaged gradients through the first tower's optimizer.
            avg_grads = self.average_gradients([t.grads for t in self.towers])
            self.update_batch = self.model.optimizer.apply_gradients(avg_grads)

        self.update_dict = {"update_batch": self.update_batch}
        self.update_dict.update(
            {
                "value_loss_" + str(i): self.towers[i].value_loss
                for i in range(len(self.towers))
            }
        )
        self.update_dict.update(
            {
                "policy_loss_" + str(i): self.towers[i].policy_loss
                for i in range(len(self.towers))
            }
        )

    def create_model(self, brain, trainer_params, reward_signal_configs, seed):
        """
        Create PPO models, one on each device.
        :param brain: Assigned Brain object.
        :param trainer_params: Defined training parameters.
        :param reward_signal_configs: Reward signal config.
        :param seed: Random seed.
        """
        self.devices = get_devices()
        self.towers = []
        with self.graph.as_default():
            with tf.variable_scope(TOWER_SCOPE_NAME, reuse=tf.AUTO_REUSE):
                for device in self.devices:
                    with tf.device(device):
                        self.towers.append(
                            PPOModel(
                                brain=brain,
                                lr=float(trainer_params["learning_rate"]),
                                h_size=int(trainer_params["hidden_units"]),
                                epsilon=float(trainer_params["epsilon"]),
                                beta=float(trainer_params["beta"]),
                                max_step=float(trainer_params["max_steps"]),
                                normalize=trainer_params["normalize"],
                                use_recurrent=trainer_params["use_recurrent"],
                                num_layers=int(trainer_params["num_layers"]),
                                m_size=self.m_size,
                                seed=seed,
                                stream_names=list(reward_signal_configs.keys()),
                                vis_encode_type=EncoderType(
                                    trainer_params.get("vis_encode_type", "simple")
                                ),
                            )
                        )
                        self.towers[-1].create_ppo_optimizer()
        self.model = self.towers[0]

    @timed
    def update(self, mini_batch, num_sequences):
        """
        Updates model using buffer.
        :param mini_batch: Experience batch.
        :param num_sequences: Number of trajectories in batch.
        :return: Output from update process.
        """
        feed_dict = {}

        # Slice the minibatch into one contiguous chunk per device.
        device_batch_size = num_sequences // len(self.devices)
        device_batches = []
        for i in range(len(self.devices)):
            device_batches.append(
                {
                    k: v[i * device_batch_size : (i + 1) * device_batch_size]
                    for (k, v) in mini_batch.items()
                }
            )

        for batch, tower in zip(device_batches, self.towers):
            feed_dict.update(self.construct_feed_dict(tower, batch, num_sequences))

        out = self._execute_model(feed_dict, self.update_dict)
        run_out = {}
        run_out["value_loss"] = np.mean(
            [out["value_loss_" + str(i)] for i in range(len(self.towers))]
        )
        run_out["policy_loss"] = np.mean(
            [out["policy_loss_" + str(i)] for i in range(len(self.towers))]
        )
        run_out["update_batch"] = out["update_batch"]
        return run_out

    def average_gradients(self, tower_grads):
        """
        Average gradients from all towers.
        :param tower_grads: Gradients from all towers.
        """
        average_grads = []
        for grad_and_vars in zip(*tower_grads):
            # Skip variables that received no gradient on any tower.
            grads = [g for g, _ in grad_and_vars if g is not None]
            if not grads:
                continue
            avg_grad = tf.reduce_mean(tf.stack(grads), 0)
            var = grad_and_vars[0][1]
            average_grads.append((avg_grad, var))
        return average_grads


def get_devices():
    """
    Get all available GPU devices.
    """
    local_device_protos = device_lib.list_local_devices()
    devices = [x.name for x in local_device_protos if x.device_type == "GPU"]
    return devices
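The data-parallel split in update() above is plain contiguous slicing: the minibatch is divided into len(self.devices) equal chunks of device_batch_size sequences, and each chunk is fed to one tower. A small standalone illustration of that partitioning (the arrays and device names are made up for the example):

import numpy as np

# Eight sequences split across two hypothetical devices.
devices = ["/device:GPU:0", "/device:GPU:1"]
mini_batch = {"actions": np.arange(8), "advantages": np.arange(8) * 0.1}
num_sequences = 8

device_batch_size = num_sequences // len(devices)  # 4 sequences per device
device_batches = [
    {
        k: v[i * device_batch_size : (i + 1) * device_batch_size]
        for k, v in mini_batch.items()
    }
    for i in range(len(devices))
]
# device_batches[0]["actions"] -> [0 1 2 3]
# device_batches[1]["actions"] -> [4 5 6 7]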
ml-agents/mlagents/trainers/tests/test_multigpu.py:

import unittest.mock as mock
import pytest

import numpy as np
import tensorflow as tf
import yaml

from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy, get_devices
from mlagents.envs import UnityEnvironment, BrainParameters
from mlagents.envs.mock_communicator import MockCommunicator
from mlagents.trainers.tests.mock_brain import create_mock_brainparams


@pytest.fixture
def dummy_config():
    return yaml.safe_load(
        """
        trainer: ppo
        batch_size: 32
        beta: 5.0e-3
        buffer_size: 512
        epsilon: 0.2
        hidden_units: 128
        lambd: 0.95
        learning_rate: 3.0e-4
        max_steps: 5.0e4
        normalize: true
        num_epoch: 5
        num_layers: 2
        time_horizon: 64
        sequence_length: 64
        summary_freq: 1000
        use_recurrent: false
        memory_size: 8
        curiosity_strength: 0.0
        curiosity_enc_size: 1
        reward_signals:
            extrinsic:
                strength: 1.0
                gamma: 0.99
        """
    )


@mock.patch("mlagents.trainers.ppo.multi_gpu_policy.get_devices")
def test_create_model(mock_get_devices, dummy_config):
    tf.reset_default_graph()
    mock_get_devices.return_value = [
        "/device:GPU:0",
        "/device:GPU:1",
        "/device:GPU:2",
        "/device:GPU:3",
    ]

    trainer_parameters = dummy_config
    trainer_parameters["model_path"] = ""
    trainer_parameters["keep_checkpoints"] = 3
    brain = create_mock_brainparams()

    # One tower should be created per (mocked) GPU device.
    policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False)
    assert len(policy.towers) == len(mock_get_devices.return_value)


@mock.patch("mlagents.trainers.ppo.multi_gpu_policy.get_devices")
def test_average_gradients(mock_get_devices, dummy_config):
    tf.reset_default_graph()
    mock_get_devices.return_value = [
        "/device:GPU:0",
        "/device:GPU:1",
        "/device:GPU:2",
        "/device:GPU:3",
    ]

    trainer_parameters = dummy_config
    trainer_parameters["model_path"] = ""
    trainer_parameters["keep_checkpoints"] = 3
    brain = create_mock_brainparams()
    with tf.Session() as sess:
        policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False)
        var = tf.Variable(0)
        tower_grads = [
            [(tf.constant(0.1), var)],
            [(tf.constant(0.2), var)],
            [(tf.constant(0.3), var)],
            [(tf.constant(0.4), var)],
        ]
        avg_grads = policy.average_gradients(tower_grads)

        init = tf.global_variables_initializer()
        sess.run(init)
        run_out = sess.run(avg_grads)
        assert run_out == [(0.25, 0)]


@mock.patch("mlagents.trainers.tf_policy.TFPolicy._execute_model")
@mock.patch("mlagents.trainers.ppo.policy.PPOPolicy.construct_feed_dict")
@mock.patch("mlagents.trainers.ppo.multi_gpu_policy.get_devices")
def test_update(
    mock_get_devices, mock_construct_feed_dict, mock_execute_model, dummy_config
):
    tf.reset_default_graph()
    mock_get_devices.return_value = ["/device:GPU:0", "/device:GPU:1"]
    mock_construct_feed_dict.return_value = {}
    mock_execute_model.return_value = {
        "value_loss_0": 0.1,
        "value_loss_1": 0.3,
        "policy_loss_0": 0.5,
        "policy_loss_1": 0.7,
        "update_batch": None,
    }

    trainer_parameters = dummy_config
    trainer_parameters["model_path"] = ""
    trainer_parameters["keep_checkpoints"] = 3
    brain = create_mock_brainparams()
    policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False)
    mock_mini_batch = mock.Mock()
    mock_mini_batch.items.return_value = [("action", [1, 2]), ("value", [3, 4])]
    run_out = policy.update(mock_mini_batch, 1)

    # The minibatch is sliced once per device, and the reported losses are the
    # mean of the per-tower losses returned by the mocked session run.
    assert mock_mini_batch.items.call_count == len(mock_get_devices.return_value)
    assert mock_construct_feed_dict.call_count == len(mock_get_devices.return_value)
    assert run_out["value_loss"] == 0.2
    assert run_out["policy_loss"] == 0.6


if __name__ == "__main__":
    pytest.main()
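Each test patches get_devices, so the device list the policy sees is controlled by the test rather than by the machine it runs on. Besides the __main__ guard above, the module can be invoked programmatically; a small example, assuming it is run from the repository root:

import pytest

# Run only this test module, quietly.
pytest.main(["-q", "ml-agents/mlagents/trainers/tests/test_multigpu.py"])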