浏览代码

first commit

/distributed-training
Anupam Bhatnagar 5 年前
当前提交
001fce2a
共有 6 个文件被更改,包括 28 次插入1 次删除
  1. 3
      ml-agents/mlagents/tf_utils/tf.py
  2. 6
      ml-agents/mlagents/trainers/learn.py
  3. 5
      ml-agents/mlagents/trainers/policy/tf_policy.py
  4. 5
      ml-agents/mlagents/trainers/ppo/optimizer.py
  5. 6
      ml-agents/mlagents/trainers/trainer/trainer.py
  6. 4
      ml-agents/mlagents/trainers/trainer_controller.py

3
ml-agents/mlagents/tf_utils/tf.py


# Everywhere else is caught by the banned-modules setting for flake8
import tensorflow as tf # noqa I201
from distutils.version import LooseVersion
import horovod.tensorflow as hvd
# LooseVersion handles things "1.2.3a" or "4.5.6-rc7" fairly sensibly.

"""
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
# For multi-GPU training, set allow_soft_placement to True to allow
# placing the operation into an alternative device automatically
# to prevent from exceptions if the device doesn't suppport the operation

6
ml-agents/mlagents/trainers/learn.py


from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.exception import UnityEnvironmentException
<<<<<<< HEAD
=======
from mlagents.logging_util import create_logger
import horovod.tensorflow as hvd
>>>>>>> first commit
def _create_parser():

sampler_manager, resampling_interval = create_sampler_manager(
options.sampler_config, run_seed
)
hvd.init()
trainer_factory = TrainerFactory(
options.trainer_config,
summaries_dir,

5
ml-agents/mlagents/trainers/policy/tf_policy.py


from mlagents.trainers.brain_conversion_utils import get_global_agent_id
from mlagents_envs.base_env import BatchedStepResult
from mlagents.trainers.models import ModelUtils
import horovod.tensorflow as hvd
logger = get_logger(__name__)

self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
init = tf.global_variables_initializer()
self.sess.run(init)
self.sess.run(hvd.broadcast_global_variables(0))
def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
with self.graph.as_default():

:param steps: The number of steps the model was trained for
:return:
"""
if hvd.rank() != 0:
return
with self.graph.as_default():
last_checkpoint = self.model_path + "/model-" + str(steps) + ".ckpt"
self.saver.save(self.sess, last_checkpoint)

5
ml-agents/mlagents/trainers/ppo/optimizer.py


from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.buffer import AgentBuffer
import horovod.tensorflow as hvd
class PPOOptimizer(TFOptimizer):

)
def _create_ppo_optimizer_ops(self):
self.tf_optimizer = self.create_optimizer_op(self.learning_rate)
self.tf_optimizer = self.create_optimizer_op(self.learning_rate * hvd.size())
if hvd is not None:
self.tf_optimizer = hvd.DistributedOptimizer(self.tf_optimizer)
self.grads = self.tf_optimizer.compute_gradients(self.loss)
self.update_batch = self.tf_optimizer.minimize(self.loss)

6
ml-agents/mlagents/trainers/trainer/trainer.py


from mlagents.trainers.policy import Policy
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
import horovod.tensorflow as hvd
logger = get_logger(__name__)

"""
Saves the model
"""
if hvd.rank() != 0:
return
self.get_policy(name_behavior_id).save_model(self.get_step)
def export_model(self, name_behavior_id: str) -> None:

if hvd.rank() != 0:
return
policy = self.get_policy(name_behavior_id)
settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
export_policy_model(settings, policy.graph, policy.sess)

4
ml-agents/mlagents/trainers/trainer_controller.py


from mlagents.trainers.trainer_util import TrainerFactory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManager
import horovod.tensorflow as hvd
class TrainerController(object):

"""
Saves current model to checkpoint folder.
"""
if hvd.rank() != 0:
return
for brain_name in self.trainers.keys():
for name_behavior_id in self.brain_name_to_identifier[brain_name]:
self.trainers[brain_name].save_model(name_behavior_id)

正在加载...
取消
保存