比较提交

...
此合并请求有变更与目标分支冲突。
/config/ppo/3DBall.yaml
/config/sac/3DBall.yaml
/ml-agents/mlagents/trainers/learn.py
/ml-agents/mlagents/trainers/trainer_controller.py
/ml-agents/mlagents/trainers/stats.py
/ml-agents/mlagents/trainers/ppo/trainer.py
/ml-agents/mlagents/trainers/sac/trainer.py
/ml-agents/mlagents/trainers/trainer/rl_trainer.py
/ml-agents/mlagents/tf_utils/tf.py
/ml-agents/mlagents/trainers/optimizer/tf_optimizer.py
/ml-agents/mlagents/trainers/policy/tf_policy.py

1 次代码提交

作者 SHA1 备注 提交日期
Anupam Bhatnagar 24d5f881 first commit 5 年前
共有 11 个文件被更改,包括 53 次插入12 次删除
  1. 2
      config/ppo/3DBall.yaml
  2. 2
      config/sac/3DBall.yaml
  3. 2
      ml-agents/mlagents/tf_utils/tf.py
  4. 3
      ml-agents/mlagents/trainers/learn.py
  5. 17
      ml-agents/mlagents/trainers/optimizer/tf_optimizer.py
  6. 4
      ml-agents/mlagents/trainers/policy/tf_policy.py
  7. 7
      ml-agents/mlagents/trainers/ppo/trainer.py
  8. 4
      ml-agents/mlagents/trainers/sac/trainer.py
  9. 7
      ml-agents/mlagents/trainers/stats.py
  10. 10
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  11. 7
      ml-agents/mlagents/trainers/trainer_controller.py

2
config/ppo/3DBall.yaml


gamma: 0.99
strength: 1.0
keep_checkpoints: 5
max_steps: 500000
max_steps: 100000
time_horizon: 1000
summary_freq: 12000
threaded: true

2
config/sac/3DBall.yaml


gamma: 0.99
strength: 1.0
keep_checkpoints: 5
max_steps: 200000
max_steps: 100000
time_horizon: 1000
summary_freq: 12000
threaded: true

2
ml-agents/mlagents/tf_utils/tf.py


# Everywhere else is caught by the banned-modules setting for flake8
import tensorflow as tf # noqa I201
from distutils.version import LooseVersion
import horovod.tensorflow as hvd
# LooseVersion handles things "1.2.3a" or "4.5.6-rc7" fairly sensibly.

"""
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
# For multi-GPU training, set allow_soft_placement to True to allow
# placing the operation into an alternative device automatically
# to prevent from exceptions if the device doesn't suppport the operation

3
ml-agents/mlagents/trainers/learn.py


add_metadata as add_timer_metadata,
)
from mlagents_envs import logging_util
import horovod.tensorflow as hvd
logger = logging_util.get_logger(__name__)

options.curriculum, env_manager, restore=checkpoint_settings.resume
)
maybe_add_samplers(options.parameter_randomization, env_manager, run_seed)
hvd.init()
trainer_factory = TrainerFactory(
options.behaviors,
write_path,

17
ml-agents/mlagents/trainers/optimizer/tf_optimizer.py


from mlagents.trainers.settings import TrainerSettings, RewardSignalType
from mlagents.trainers.components.bc.module import BCModule
try:
import horovod.tensorflow as hvd
except ImportError:
hvd = None
class TFOptimizer(Optimizer): # pylint: disable=W0223
def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings):

def create_optimizer_op(
self, learning_rate: tf.Tensor, name: str = "Adam"
) -> tf.train.Optimizer:
return tf.train.AdamOptimizer(learning_rate=learning_rate, name=name)
if hvd is not None:
adam_optimizer = tf.train.AdamOptimizer(
learning_rate=learning_rate, name=name
)
horovod_optimizer = hvd.DistributedOptimizer(adam_optimizer)
else:
adam_optimizer = tf.train.AdamOptimizer(
learning_rate=learning_rate, name=name
)
return horovod_optimizer if hvd is not None else adam_optimizer
def _execute_model(
self, feed_dict: Dict[tf.Tensor, np.ndarray], out_dict: Dict[str, tf.Tensor]

4
ml-agents/mlagents/trainers/policy/tf_policy.py


from mlagents.trainers.settings import TrainerSettings, NetworkSettings
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers import __version__
import horovod.tensorflow as hvd
logger = get_logger(__name__)

self._load_graph(self.model_path, reset_global_steps=reset_steps)
else:
self._initialize_graph()
self.sess.run(hvd.broadcast_global_variables(0))
def get_weights(self):
with self.graph.as_default():

:param steps: The number of steps the model was trained for
:return:
"""
if hvd.rank() != 0:
return
with self.graph.as_default():
last_checkpoint = os.path.join(self.model_path, f"model-{steps}.ckpt")
self.saver.save(self.sess, last_checkpoint)

7
ml-agents/mlagents/trainers/ppo/trainer.py


Uses demonstration_buffer to update the policy.
The reward signal generators must be updated in this method at their own pace.
"""
buffer_length = self.update_buffer.num_experiences
self._maybe_write_summary(self.get_step + self.hyperparameters.buffer_size)
self._maybe_save_model(self.get_step + self.hyperparameters.buffer_size)
self._increment_step(self.hyperparameters.buffer_size, self.brain_name)
# Make sure batch_size is a multiple of sequence length. During training, we
# will need to reshape the data into a batch_size x sequence_length tensor.

for _ in range(num_epoch):
self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
buffer = self.update_buffer
max_num_batch = buffer_length // batch_size
max_num_batch = self.hyperparameters.buffer_size // batch_size
for i in range(0, max_num_batch * batch_size, batch_size):
update_stats = self.optimizer.update(
buffer.make_mini_batch(i, i + batch_size), n_sequences

4
ml-agents/mlagents/trainers/sac/trainer.py


"""
has_updated = False
self.cumulative_returns_since_policy_update.clear()
self._maybe_write_summary(self.get_step + int(self.steps_per_update))
self._maybe_save_model(self.get_step + int(self.steps_per_update))
self._increment_step(self.hyperparameters.buffer_size, self.brain_name)
n_sequences = max(
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)

7
ml-agents/mlagents/trainers/stats.py


from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import set_gauge
from mlagents.tf_utils import tf, generate_session_config
import horovod.tensorflow as hvd
logger = get_logger(__name__)

) -> None:
is_training = "Not Training."
if "Is Training" in values:
stats_summary = stats_summary = values["Is Training"]
stats_summary = values["Is Training"]
rank = hvd.rank()
"{}: Step: {}. "
"Horovod Rank: {}. {}: Step: {}. "
rank,
category,
step,
time.time() - self.training_start_time,

10
ml-agents/mlagents/trainers/trainer/rl_trainer.py


"""
return False
@abc.abstractmethod
def _update_policy(self) -> bool:
"""
Uses demonstration_buffer to update model.

Takes a trajectory and processes it, putting it into the update buffer.
:param trajectory: The Trajectory tuple containing the steps to be processed.
"""
self._maybe_write_summary(self.get_step + len(trajectory.steps))
self._maybe_save_model(self.get_step + len(trajectory.steps))
self._increment_step(len(trajectory.steps), trajectory.behavior_id)
pass
# self._maybe_write_summary(self.get_step + len(trajectory.steps))
# self._maybe_save_model(self.get_step + len(trajectory.steps))
# self._increment_step(len(trajectory.steps), trajectory.behavior_id)
def _maybe_write_summary(self, step_after_process: int) -> None:
"""

"""
if self._next_summary_step == 0: # Don't write out the first one
self._next_summary_step = self._get_next_interval_step(self.summary_freq)
if step_after_process >= self._next_summary_step and self.get_step != 0:
if step_after_process >= self._next_summary_step:
self._write_summary(self._next_summary_step)
def _maybe_save_model(self, step_after_process: int) -> None:

7
ml-agents/mlagents/trainers/trainer_controller.py


from mlagents.trainers.agent_processor import AgentManager
from mlagents.trainers.settings import CurriculumSettings
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
import horovod.tensorflow as hvd
class TrainerController(object):

"""
Saves current model to checkpoint folder.
"""
if hvd.rank() != 0:
return
for brain_name in self.trainers.keys():
for name_behavior_id in self.brain_name_to_identifier[brain_name]:
self.trainers[brain_name].save_model(name_behavior_id)

"""
Exports latest saved models to .nn format for Unity embedding.
"""
if hvd.rank() != 0:
return
for brain_name in self.trainers.keys():
for name_behavior_id in self.brain_name_to_identifier[brain_name]:
self.trainers[brain_name].export_model(name_behavior_id)

正在加载...
取消
保存