
EXPERIMENTAL horovod support

/develop-horovod
Jonathan Harper, 5 years ago
Commit 3fc14963
10 changed files with 297 additions and 26 deletions
  1. Dockerfile (82 changes)
  2. ml-agents-envs/mlagents/envs/environment.py (21 changes)
  3. ml-agents/mlagents/trainers/learn.py (2 changes)
  4. ml-agents/mlagents/trainers/ppo/models.py (2 changes)
  5. ml-agents/mlagents/trainers/sac/models.py (8 changes)
  6. ml-agents/mlagents/trainers/tf_policy.py (10 changes)
  7. ml-agents/setup.py (2 changes)
  8. .dockerignore (3 changes)
  9. Dockerfile-bak (129 changes)
  10. horovod-mlagents.yaml (64 changes)

Dockerfile (82 changes)


# Based off of python:3.6-slim, except that we are using ubuntu instead of debian.
FROM ubuntu:16.04
FROM nvidia/cudagl:9.0-devel-ubuntu16.04
# TensorFlow version is tightly coupled to CUDA and cuDNN so it should be selected carefully
ENV TENSORFLOW_VERSION=1.12.0
ENV PYTORCH_VERSION=1.1.0
ENV TORCHVISION_VERSION=0.2.2.post3
ENV CUDNN_VERSION=7.4.1.5-1+cuda9.0
ENV NCCL_VERSION=2.3.7-1+cuda9.0
# Set default shell to /bin/bash
SHELL ["/bin/bash", "-cu"]
# ensure local python is preferred over distribution python
ENV PATH /usr/local/bin:$PATH

rm -f get-pip.py
RUN apt-get update && apt-get -y upgrade
RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \
build-essential \
cmake \
g++-4.8 \
git \
curl \
vim \
wget \
ca-certificates \
libcudnn7=${CUDNN_VERSION} \
libnccl2=${NCCL_VERSION} \
libnccl-dev=${NCCL_VERSION} \
libjpeg-dev \
libpng-dev \
librdmacm1 \
libibverbs1 \
libibverbs-dev \
xvfb
# xvfb is used to do CPU based rendering of Unity
RUN apt-get install -y xvfb
# Install TensorFlow
RUN pip install future typing
RUN pip install numpy \
tensorflow-gpu==${TENSORFLOW_VERSION} \
keras \
h5py
# Install Open MPI
RUN mkdir /tmp/openmpi && \
cd /tmp/openmpi && \
wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \
tar zxf openmpi-4.0.0.tar.gz && \
cd openmpi-4.0.0 && \
./configure --enable-orterun-prefix-by-default && \
make -j $(nproc) all && \
make install && \
ldconfig && \
rm -rf /tmp/openmpi
# Install Horovod, temporarily using CUDA stubs
RUN ldconfig /usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs && \
HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_WITH_TENSORFLOW=1 pip install --no-cache-dir horovod && \
ldconfig
# Install OpenSSH for MPI to communicate between containers
RUN apt-get install -y --no-install-recommends openssh-client openssh-server && \
mkdir -p /var/run/sshd
# Allow OpenSSH to talk to containers without asking for confirmation
RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \
echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \
mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config
# Install ml-agents-envs package locally
COPY ml-agents-envs /ml-agents-envs

WORKDIR /ml-agents
RUN pip install -e .
# port 5005 is the port used in Editor training.
EXPOSE 5005
# setup google-cloud-sdk, which is used to copy files to gcs after the training is finished
RUN apt-get install --yes --no-install-recommends \
ca-certificates \
curl \
&& echo "deb http://packages.cloud.google.com/apt cloud-sdk-xenial main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \
&& curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - \
&& apt-get update \
&& apt-get install --yes google-cloud-sdk \
&& apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
ENTRYPOINT ["mlagents-learn"]
COPY unity-volume /unity-volume
RUN chmod +x /unity-volume/*.x86_64
CMD ["/bin/bash"]

ml-agents-envs/mlagents/envs/environment.py (21 changes)


        atexit.register(self._close)
        self.port = base_port + worker_id
        self._buffer_size = 12000
        self._version_ = "API-9"
        self._version_ = "API-8"
        self._loaded = (
            False
        )  # If true, this means the environment was successfully loaded

                "If the environment name is None, "
                "the worker-id must be 0 in order to connect with the Editor."
            )
        if file_name is not None:
            self.executable_launcher(file_name, docker_training, no_graphics, args)
        else:
            logger.info(
                "Start training by pressing the Play button in the Unity Editor."
            )
        # if file_name is not None:
        #     self.executable_launcher(file_name, docker_training, no_graphics, args)
        # else:
        #     logger.info(
        #         "Start training by pressing the Play button in the Unity Editor."
        #     )
        self._loaded = True
        rl_init_parameters_in = UnityRLInitializationInput(seed=seed)

            raise
        # TODO : think of a better way to expose the academyParameters
        self._unity_version = aca_params.version
        if self._unity_version != self._version_:
            self._close()
            raise UnityEnvironmentException(
                "The API number is not compatible between Unity and python. Python API : {0}, Unity API : "
                "{1}.\nPlease go to https://github.com/Unity-Technologies/ml-agents to download the latest version "
                "of ML-Agents.".format(self._version_, self._unity_version)
            )
        self._n_agents: Dict[str, int] = {}
        self._is_first_message = True
        self._academy_name = aca_params.name

ml-agents/mlagents/trainers/learn.py (2 changes)


from mlagents.envs.exception import SamplerException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.subprocess_env_manager import SubprocessEnvManager
import horovod.tensorflow as hvd
class CommandLineOptions(NamedTuple):

        options.sampler_file_path, env.reset_parameters, run_seed
    )
    hvd.init()
    trainers = initialize_trainers(
        trainer_config,
        env.external_brains,
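
The only change in learn.py is calling hvd.init() before the trainers are built. After init, every MPI process knows its own rank and the world size, which is what the later per-rank logic (GPU pinning, rank-0 checkpointing) relies on. Below is a minimal sketch of how rank information is commonly used to keep per-process seeds and ports distinct; the variable names and offsets are illustrative assumptions, not code from this commit:

import horovod.tensorflow as hvd

hvd.init()
run_seed = 1234 + hvd.rank()        # hypothetical: decorrelate seeds across ranks
base_port = 5005 + hvd.rank() * 10  # hypothetical: avoid port collisions between ranks
print("process %d of %d: seed=%d, base_port=%d"
      % (hvd.rank(), hvd.size(), run_seed, base_port))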

ml-agents/mlagents/trainers/ppo/models.py (2 changes)


import tensorflow as tf
from mlagents.trainers.models import LearningModel, EncoderType, LearningRateSchedule
import horovod.tensorflow as hvd
logger = logging.getLogger("mlagents.trainers")

    def create_ppo_optimizer(self):
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        self.optimizer = hvd.DistributedOptimizer(self.optimizer)
        self.grads = self.optimizer.compute_gradients(self.loss)
        self.update_batch = self.optimizer.minimize(self.loss)
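
Wrapping the Adam optimizer in hvd.DistributedOptimizer makes compute_gradients() (and therefore minimize()) average gradients across all ranks with an allreduce before they are applied, so every worker takes the same step. A standalone sketch of the same wrapping with a toy loss, runnable on its own; note that Horovod's documentation additionally recommends scaling the learning rate by hvd.size(), which this diff does not do:

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
x = tf.Variable(1.0)
loss = tf.square(x - 3.0)                        # toy loss standing in for the PPO loss
opt = tf.train.AdamOptimizer(learning_rate=3e-4)
opt = hvd.DistributedOptimizer(opt)              # gradients are allreduced across ranks
update_op = opt.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(hvd.broadcast_global_variables(0))  # start all ranks from identical weights
    sess.run(update_op)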

ml-agents/mlagents/trainers/sac/models.py (8 changes)


import tensorflow as tf
from mlagents.trainers.models import LearningModel, LearningRateSchedule, EncoderType
import tensorflow.contrib.layers as c_layers
import horovod.tensorflow as hvd
DISCRETE_TARGET_ENTROPY_SCALE = 0.2 # Roughly equal to e-greedy 0.05
CONTINUOUS_TARGET_ENTROPY_SCALE = 1.0 # TODO: Make these an optional hyperparam.
DISCRETE_TARGET_ENTROPY_SCALE = 0.1 # Roughly equal to e-greedy 0.05
CONTINUOUS_TARGET_ENTROPY_SCALE = 0.8 # TODO: Make these an optional hyperparam.
LOGGER = logging.getLogger("mlagents.trainers")

        the policy, value, and entropy updates, as well as the target network update.
        """
        policy_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        policy_optimizer = hvd.DistributedOptimizer(policy_optimizer)
        entropy_optimizer = hvd.DistributedOptimizer(entropy_optimizer)
        value_optimizer = hvd.DistributedOptimizer(value_optimizer)
        self.target_update_op = [
            tf.assign(target, (1 - self.tau) * target + self.tau * source)
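
The SAC change is the same pattern applied three times: the policy, entropy, and value optimizers are each wrapped in hvd.DistributedOptimizer, so each of the three update ops performs its own allreduce. The target-entropy scale constants (0.2/1.0 versus 0.1/0.8) are also touched in the same hunk.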

ml-agents/mlagents/trainers/tf_policy.py (10 changes)


import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from mlagents.envs.exception import UnityException
from mlagents.envs.policy import Policy

        self.use_continuous_act = brain.vector_action_space_type == "continuous"
        self.model_path = trainer_parameters["model_path"]
        self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
        self.graph = tf.Graph()
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        config.gpu_options.allow_growth = True
        # For multi-GPU training, set allow_soft_placement to True to allow
        # placing the operation into an alternative device automatically

        self.graph = tf.Graph()
        self.sess = tf.Session(config=config, graph=self.graph)
        self.saver = None
        if self.use_recurrent:

            self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
            init = tf.global_variables_initializer()
            self.sess.run(init)
            self.sess.run(hvd.broadcast_global_variables(0))

    def _load_graph(self):
        with self.graph.as_default():

        :param steps: The number of steps the model was trained for
        :return:
        """
        if hvd.rank() != 0:
            return
        with self.graph.as_default():
            last_checkpoint = self.model_path + "/model-" + str(steps) + ".cptk"
            self.saver.save(self.sess, last_checkpoint)

        """
        Exports latest saved model to .nn format for Unity embedding.
        """
        if hvd.rank() != 0:
            return
        with self.graph.as_default():
            target_nodes = ",".join(self._process_graph())
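
Taken together, tf_policy.py adds the remaining pieces of the usual Horovod recipe: each process is pinned to one GPU via hvd.local_rank(), rank 0 broadcasts the freshly initialized variables so all workers start from identical weights, and only rank 0 saves checkpoints or exports the .nn file so workers do not overwrite each other's output. A compact, self-contained sketch of that lifecycle; the variable and path names are placeholders, not the ml-agents ones:

import os
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())  # one GPU per process

w = tf.Variable(tf.random_normal([4, 4]), name="w")  # placeholder model parameters
saver = tf.train.Saver(max_to_keep=5)

with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(hvd.broadcast_global_variables(0))  # sync initial weights from rank 0
    if hvd.rank() == 0:                          # only rank 0 writes to disk
        os.makedirs("/tmp/horovod-demo", exist_ok=True)
        saver.save(sess, "/tmp/horovod-demo/model-0.cptk")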

ml-agents/setup.py (2 changes)


"Pillow>=4.2.1",
"protobuf>=3.6",
"pyyaml",
"tensorflow>=1.7,<2.0",
"horovod",
],
python_requires=">=3.6.1",
entry_points={"console_scripts": ["mlagents-learn=mlagents.trainers.learn:main"]},
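
Listing horovod in install_requires means pip compiles Horovod's native extensions at install time, which needs a C++ toolchain plus an MPI installation (and CUDA/NCCL for the HOROVOD_GPU_ALLREDUCE=NCCL path). Inside the Docker image above that is taken care of; a plain pip install -e . on a machine without those toolchains will likely fail at the horovod build step.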

.dockerignore (3 changes)


UnitySDK/
.git/
venv-harper/

Dockerfile-bak (129 changes)


# Based off of python:3.6-slim, except that we are using ubuntu instead of debian.
FROM ubuntu:16.04
# ensure local python is preferred over distribution python
ENV PATH /usr/local/bin:$PATH
# http://bugs.python.org/issue19846
# > At the moment, setting "LANG=C" on a Linux system *fundamentally breaks Python 3*, and that's not OK.
ENV LANG C.UTF-8
# runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
libexpat1 \
libffi6 \
libgdbm3 \
libreadline6 \
libsqlite3-0 \
libssl1.0.0 \
&& rm -rf /var/lib/apt/lists/*
ENV GPG_KEY 0D96DF4D4110E5C43FBFB17F2D347EA6AA65421D
ENV PYTHON_VERSION 3.6.4
RUN set -ex \
&& buildDeps=" \
dpkg-dev \
gcc \
libbz2-dev \
libc6-dev \
libexpat1-dev \
libffi-dev \
libgdbm-dev \
liblzma-dev \
libncursesw5-dev \
libreadline-dev \
libsqlite3-dev \
libssl-dev \
make \
tcl-dev \
tk-dev \
wget \
xz-utils \
zlib1g-dev \
# as of Stretch, "gpg" is no longer included by default
$(command -v gpg > /dev/null || echo 'gnupg dirmngr') \
" \
&& apt-get update && apt-get install -y $buildDeps --no-install-recommends && rm -rf /var/lib/apt/lists/* \
\
&& wget -O python.tar.xz "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz" \
&& wget -O python.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc" \
&& export GNUPGHOME="$(mktemp -d)" \
&& gpg --keyserver ha.pool.sks-keyservers.net --recv-keys "$GPG_KEY" \
&& gpg --batch --verify python.tar.xz.asc python.tar.xz \
&& rm -rf "$GNUPGHOME" python.tar.xz.asc \
&& mkdir -p /usr/src/python \
&& tar -xJC /usr/src/python --strip-components=1 -f python.tar.xz \
&& rm python.tar.xz \
\
&& cd /usr/src/python \
&& gnuArch="$(dpkg-architecture --query DEB_BUILD_GNU_TYPE)" \
&& ./configure \
--build="$gnuArch" \
--enable-loadable-sqlite-extensions \
--enable-shared \
--with-system-expat \
--with-system-ffi \
--without-ensurepip \
&& make -j "$(nproc)" \
&& make install \
&& ldconfig \
\
&& apt-get purge -y --auto-remove $buildDeps \
\
&& find /usr/local -depth \
\( \
\( -type d -a \( -name test -o -name tests \) \) \
-o \
\( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \
\) -exec rm -rf '{}' + \
&& rm -rf /usr/src/python
# make some useful symlinks that are expected to exist
RUN cd /usr/local/bin \
&& ln -s idle3 idle \
&& ln -s pydoc3 pydoc \
&& ln -s python3 python \
&& ln -s python3-config python-config
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
g++-4.8 \
git \
curl \
vim \
wget \
ca-certificates \
xvfb
# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'"
ENV PYTHON_PIP_VERSION 9.0.3
RUN set -ex; \
wget -O get-pip.py 'https://bootstrap.pypa.io/get-pip.py'; \
\
python get-pip.py \
--disable-pip-version-check \
--no-cache-dir \
"pip==$PYTHON_PIP_VERSION" \
; \
pip --version; \
\
find /usr/local -depth \
\( \
\( -type d -a \( -name test -o -name tests \) \) \
-o \
\( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \
\) -exec rm -rf '{}' +; \
rm -f get-pip.py
COPY unity-volume /unity-volume
RUN chmod +x /unity-volume/*.x86_64
# port 5005 is the port used in Editor training.
EXPOSE 5005
ENTRYPOINT ["mlagents-learn"]

horovod-mlagents.yaml (64 changes)


apiVersion: kubeflow.org/v1alpha2
kind: MPIJob
metadata:
  name: mlagents-horovod-test
spec:
  slotsPerWorker: 1
  cleanPodPolicy: Running
  mpiReplicaSpecs:
    Launcher:
      replicas: 1
      template:
        spec:
          containers:
          - image: gcr.io/unity-ml-agents-expts-test/mlagents-horovod:latest
            name: mlagents-horovod-test
            resources:
              limits:
                cpu: 4
            command: ["/bin/sh", "-c"]
            args: ["
              mpirun --allow-run-as-root -np 16 -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x HOROVOD_AUTOTUNE=1 -x PATH -mca pml ob1 -mca btl ^openib mlagents-learn /unity-volume/trainer_config.yaml --run-id=snoopydist15-ppo --train --env=/unity-volume/Walker --num-envs=6;
              mpirun --allow-run-as-root -np 16 -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x HOROVOD_AUTOTUNE=1 -x PATH -mca pml ob1 -mca btl ^openib gsutil cp -r models gs://ray-volume/horovod/;
              mpirun --allow-run-as-root -np 16 -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x HOROVOD_AUTOTUNE=1 -x PATH -mca pml ob1 -mca btl ^openib gsutil cp -r summaries gs://ray-volume/horovod/;
              "]
    Worker:
      replicas: 16
      template:
        spec:
          containers:
          - image: gcr.io/unity-ml-agents-expts-test/mlagents-horovod:latest
            name: mlagents-horovod-test
            resources:
              limits:
                cpu: 2
                nvidia.com/gpu: 1
                memory: 8G
          - image: gcr.io/unity-ml-agents-expts-test/mlagents-horovod-env:latest
            name: mlagents-horovod-test-env
            resources:
              requests:
                cpu: 12
                memory: 48G
              limits:
                cpu: 12
                memory: 48G
            command: ["/bin/sh", "-c"]
            args: ["
              sleep 95s;
              cd /unity-volume;
              (xvfb-run --auto-servernum --server-args='-screen 0 640x480x24' /unity-volume/SnoopyPop15Levels_10.x86_64 --port 5005 &);
              (xvfb-run --auto-servernum --server-args='-screen 0 640x480x24' /unity-volume/SnoopyPop15Levels_10.x86_64 --port 5006 &);
              (xvfb-run --auto-servernum --server-args='-screen 0 640x480x24' /unity-volume/SnoopyPop15Levels_10.x86_64 --port 5007 &);
              (xvfb-run --auto-servernum --server-args='-screen 0 640x480x24' /unity-volume/SnoopyPop15Levels_10.x86_64 --port 5008 &);
              (xvfb-run --auto-servernum --server-args='-screen 0 640x480x24' /unity-volume/SnoopyPop15Levels_10.x86_64 --port 5009 &);
              xvfb-run --auto-servernum --server-args='-screen 0 640x480x24' /unity-volume/SnoopyPop15Levels_10.x86_64 --port 5010
              "]
            stdin: true
            tty: true
            securityContext:
              privileged: true
              capabilities:
                add:
                - SYS_ADMIN
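
How the numbers line up: the launcher's mpirun -np 16 matches Worker replicas: 16 times slotsPerWorker: 1, so one training process runs per worker pod and per GPU. The second, GPU-less container in each worker pod waits 95 seconds and then starts six copies of the Unity executable on ports 5005-5010, which appears intended to pair with --num-envs=6 in the launcher command, and the trailing mpirun gsutil calls copy the models and summaries directories to GCS once training finishes.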