
Enable learning rate decay to be disabled (#2567)

/develop-gpu-test
GitHub · 5 years ago
Current commit: 3683cc1c
11 files changed, with 88 additions and 19 deletions
  1. config/sac_trainer_config.yaml (8 lines changed)
  2. config/trainer_config.yaml (3 lines changed)
  3. docs/Training-PPO.md (16 lines changed)
  4. docs/Training-SAC.md (15 lines changed)
  5. ml-agents/mlagents/trainers/models.py (26 lines changed)
  6. ml-agents/mlagents/trainers/ppo/models.py (10 lines changed)
  7. ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py (7 lines changed)
  8. ml-agents/mlagents/trainers/ppo/policy.py (5 lines changed)
  9. ml-agents/mlagents/trainers/sac/models.py (10 lines changed)
  10. ml-agents/mlagents/trainers/sac/policy.py (5 lines changed)
  11. ml-agents/mlagents/trainers/trainer.py (2 lines changed)

config/sac_trainer_config.yaml (8 lines changed)


    hidden_units: 128
    init_entcoef: 1.0
    learning_rate: 3.0e-4
+   learning_rate_schedule: constant
    max_steps: 5.0e4
    memory_size: 256
    normalize: false

    time_horizon: 1000
    hidden_units: 64
    init_entcoef: 0.5
    max_steps: 5.0e5

3DBallHardLearning:
    normalize: true

    max_steps: 5.0e5
    buffer_size: 500000
    normalize: true
    max_steps: 2e5

    summary_freq: 3000
    train_interval: 3
    num_layers: 3
    max_steps: 5e5
    max_steps: 1e6
    hidden_units: 512

WalkerLearning:

    time_horizon: 1000
    batch_size: 128
    buffer_size: 500000
    max_steps: 1e6
    max_steps: 2e5
    summary_freq: 3000

HallwayLearning:

config/trainer_config.yaml (3 lines changed)


    hidden_units: 128
    lambd: 0.95
    learning_rate: 3.0e-4
+   learning_rate_schedule: linear
    max_steps: 5.0e4
    memory_size: 256
    normalize: false

    summary_freq: 1000
    use_recurrent: false
    vis_encode_type: simple
    reward_signals:
        extrinsic:
            strength: 1.0
            gamma: 0.99

docs/Training-PPO.md (16 lines changed)


Typical Range: `1e-5` - `1e-3`

+ ### (Optional) Learning Rate Schedule
+
+ `learning_rate_schedule` corresponds to how the learning rate is changed over time.
+ For PPO, we recommend decaying learning rate until `max_steps` so learning converges
+ more stably. However, for some cases (e.g. training for an unknown amount of time)
+ this feature can be disabled.
+
+ Options:
+ * `linear` (default): Decay `learning_rate` linearly, reaching 0 at `max_steps`.
+ * `constant`: Keep learning rate constant for the entire training run.
+
+ Options: `linear`, `constant`

### Time Horizon

`time_horizon` corresponds to how many steps of experience to collect per-agent

### Learning Rate

- This will decrease over time on a linear schedule.
+ This will decrease over time on a linear schedule by default, unless `learning_rate_schedule`
+ is set to `constant`.

### Policy Loss
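To make the `linear` option above concrete: with the default PPO settings (`learning_rate: 3.0e-4`, `max_steps: 5.0e4`), the rate shrinks toward zero as training approaches `max_steps`. Below is a minimal sketch of that behavior, assuming the power-1 polynomial decay with a tiny floor used by the trainer code later in this diff; the helper name `linear_schedule` is hypothetical.

```python
def linear_schedule(lr0, step, max_steps, floor=1e-10):
    # Linear interpolation from lr0 at step 0 down to ~0 (the floor) at max_steps.
    frac = min(step, max_steps) / max_steps
    return (lr0 - floor) * (1.0 - frac) + floor

for step in (0, 25000, 50000):
    print(step, linear_schedule(3.0e-4, step, 50000))
# 0 -> 3.0e-4, 25000 -> 1.5e-4, 50000 -> ~1e-10 (effectively zero)
```

With `constant`, the same `3.0e-4` would be used at every step instead.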

docs/Training-SAC.md (15 lines changed)


Typical Range: `1e-5` - `1e-3`

+ ### (Optional) Learning Rate Schedule
+
+ `learning_rate_schedule` corresponds to how the learning rate is changed over time.
+ For SAC, we recommend holding learning rate constant so that the agent can continue to
+ learn until its Q function converges naturally.
+
+ Options:
+ * `linear`: Decay `learning_rate` linearly, reaching 0 at `max_steps`.
+ * `constant` (default): Keep learning rate constant for the entire training run.
+
+ Options: `linear`, `constant`

### Time Horizon

`time_horizon` corresponds to how many steps of experience to collect per-agent

### Learning Rate

- This will decrease over time on a linear schedule.
+ This will stay a constant value by default, unless `learning_rate_schedule`
+ is set to `linear`.

### Policy Loss
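Because the default differs per trainer (SAC keeps the rate constant unless told otherwise, PPO decays it), switching behavior is just a matter of setting or omitting the key. Here is a small sketch using a plain dict to stand in for the parsed YAML section; `sac_section` is an illustrative name, not an ML-Agents API.

```python
# Illustrative stand-in for one trainer section of sac_trainer_config.yaml.
sac_section = {
    "learning_rate": 3.0e-4,
    # "learning_rate_schedule" omitted -> SAC falls back to its default, "constant"
}

schedule = sac_section.get("learning_rate_schedule", "constant")
assert schedule == "constant"

# Opting back into decay for SAC means setting the key explicitly:
sac_section["learning_rate_schedule"] = "linear"
```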

ml-agents/mlagents/trainers/models.py (26 lines changed)


import tensorflow as tf
import tensorflow.contrib.layers as c_layers

+ from mlagents.trainers.trainer import UnityTrainerException

logger = logging.getLogger("mlagents.trainers")

ActivationFunction = Callable[[tf.Tensor], tf.Tensor]

    SIMPLE = "simple"
    NATURE_CNN = "nature_cnn"
    RESNET = "resnet"


+ class LearningRateSchedule(Enum):
+     CONSTANT = "constant"
+     LINEAR = "linear"


class LearningModel(object):

        )
        increment_step = tf.assign(global_step, tf.add(global_step, steps_to_increment))
        return global_step, increment_step, steps_to_increment

+    @staticmethod
+    def create_learning_rate(
+        lr_schedule: LearningRateSchedule,
+        lr: float,
+        global_step: tf.Tensor,
+        max_step: int,
+    ) -> tf.Tensor:
+        if lr_schedule == LearningRateSchedule.CONSTANT:
+            learning_rate = tf.Variable(lr)
+        elif lr_schedule == LearningRateSchedule.LINEAR:
+            learning_rate = tf.train.polynomial_decay(
+                lr, global_step, max_step, 1e-10, power=1.0
+            )
+        else:
+            raise UnityTrainerException(
+                "The learning rate schedule {} is invalid.".format(lr_schedule)
+            )
+        return learning_rate

    @staticmethod
    def scaled_init(scale):
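For reference, here is a standalone sketch of what the two branches of `create_learning_rate` produce, assuming TensorFlow 1.x (the `tf.train` API used above); it mirrors the logic rather than importing ML-Agents. Note that the constant branch still returns a `tf.Variable`, so downstream optimizer code sees a tensor either way.

```python
import tensorflow as tf  # assumes TensorFlow 1.x

lr, max_step = 3.0e-4, 1000
global_step = tf.Variable(0, trainable=False, dtype=tf.int32)

# CONSTANT: a plain variable holding the initial rate, never decayed.
constant_lr = tf.Variable(lr)
# LINEAR: power-1 polynomial decay from lr down to a tiny floor at max_step.
linear_lr = tf.train.polynomial_decay(lr, global_step, max_step, 1e-10, power=1.0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in (0, 500, 1000):
        sess.run(tf.assign(global_step, step))
        print(step, sess.run(constant_lr), sess.run(linear_lr))
# constant_lr stays at 3e-4; linear_lr falls from 3e-4 to ~1e-10 at max_step.
```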

ml-agents/mlagents/trainers/ppo/models.py (10 lines changed)


import numpy as np
import tensorflow as tf

- from mlagents.trainers.models import LearningModel, EncoderType
+ from mlagents.trainers.models import LearningModel, EncoderType, LearningRateSchedule

logger = logging.getLogger("mlagents.trainers")

        self,
        brain,
        lr=1e-4,
+       lr_schedule=LearningRateSchedule.LINEAR,
        h_size=128,
        epsilon=0.2,
        beta=1e-3,

        appropriate PPO agent model for the environment.
        :param brain: BrainInfo used to generate specific network graph.
        :param lr: Learning rate.
+       :param lr_schedule: Learning rate decay schedule.
        :param h_size: Size of hidden layers
        :param epsilon: Value for policy-divergence threshold.
        :param beta: Strength of entropy regularization.

            self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy
        else:
            self.create_dc_actor_critic(h_size, num_layers, vis_encode_type)
+       self.learning_rate = self.create_learning_rate(
+           lr_schedule, lr, self.global_step, max_step
+       )
        self.create_losses(
            self.log_probs,
            self.old_log_probs,

            shape=[None], dtype=tf.float32, name="advantages"
        )
        advantage = tf.expand_dims(self.advantage, -1)
-       self.learning_rate = tf.train.polynomial_decay(
-           lr, self.global_step, max_step, 1e-10, power=1.0
-       )
        decay_epsilon = tf.train.polynomial_decay(
            epsilon, self.global_step, max_step, 0.1, power=1.0

ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py (7 lines changed)


import tensorflow as tf
from tensorflow.python.client import device_lib

from mlagents.envs.timers import timed
- from mlagents.trainers.models import EncoderType
+ from mlagents.trainers.models import EncoderType, LearningRateSchedule
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.components.reward_signals.reward_signal_factory import (

                PPOModel(
                    brain=brain,
                    lr=float(trainer_params["learning_rate"]),
+                   lr_schedule=LearningRateSchedule(
+                       trainer_params.get(
+                           "learning_rate_schedule", "linear"
+                       )
+                   ),
                    h_size=int(trainer_params["hidden_units"]),
                    epsilon=float(trainer_params["epsilon"]),
                    beta=float(trainer_params["beta"]),

ml-agents/mlagents/trainers/ppo/policy.py (5 lines changed)


from mlagents.envs.timers import timed
from mlagents.trainers import BrainInfo, ActionInfo
- from mlagents.trainers.models import EncoderType
+ from mlagents.trainers.models import EncoderType, LearningRateSchedule
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.components.reward_signals.reward_signal_factory import (

            self.model = PPOModel(
                brain=brain,
                lr=float(trainer_params["learning_rate"]),
+               lr_schedule=LearningRateSchedule(
+                   trainer_params.get("learning_rate_schedule", "linear")
+               ),
                h_size=int(trainer_params["hidden_units"]),
                epsilon=float(trainer_params["epsilon"]),
                beta=float(trainer_params["beta"]),
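The string read from the trainer config is converted to the enum by value, so an unsupported schedule name fails fast with a `ValueError` before any graph is built. Below is a minimal sketch of that lookup, redefining the enum locally so the snippet runs without ML-Agents installed.

```python
from enum import Enum

class LearningRateSchedule(Enum):  # mirrors the enum added in models.py above
    CONSTANT = "constant"
    LINEAR = "linear"

params = {"learning_rate": 3.0e-4}  # config section with no schedule key

# PPO falls back to "linear"; the SAC policy below uses "constant" instead.
ppo_schedule = LearningRateSchedule(params.get("learning_rate_schedule", "linear"))
assert ppo_schedule is LearningRateSchedule.LINEAR

try:
    LearningRateSchedule("cosine")  # not a member value -> rejected
except ValueError as err:
    print(err)  # 'cosine' is not a valid LearningRateSchedule
```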

ml-agents/mlagents/trainers/sac/models.py (10 lines changed)


import numpy as np
import tensorflow as tf

- from mlagents.trainers.models import LearningModel, EncoderType
+ from mlagents.trainers.models import LearningModel, LearningRateSchedule, EncoderType
import tensorflow.contrib.layers as c_layers

LOG_STD_MAX = 2

        self,
        brain,
        lr=1e-4,
+       lr_schedule=LearningRateSchedule.CONSTANT,
        h_size=128,
        init_entcoef=0.1,
        max_step=5e6,

        appropriate PPO agent model for the environment.
        :param brain: BrainInfo used to generate specific network graph.
        :param lr: Learning rate.
+       :param lr_schedule: Learning rate decay schedule.
        :param h_size: Size of hidden layers
        :param init_entcoef: Initial value for entropy coefficient. Set lower to learn faster,
        set higher to explore more.

            vis_encode_type=vis_encode_type,
        )
        self.create_inputs_and_outputs()
+       self.learning_rate = self.create_learning_rate(
+           lr_schedule, lr, self.global_step, max_step
+       )
        self.create_losses(
            self.policy_network.q1_heads,
            self.policy_network.q2_heads,

                shape=[None], dtype=tf.float32, name="{}_rewards".format(name)
            )
            self.rewards_holders[name] = rewards_holder
-       self.learning_rate = tf.train.polynomial_decay(
-           lr, self.global_step, max_step, 1e-10, power=1.0
-       )
        q1_losses = []
        q2_losses = []

ml-agents/mlagents/trainers/sac/policy.py (5 lines changed)


from mlagents.envs.timers import timed
from mlagents.trainers import BrainInfo, ActionInfo, BrainParameters
- from mlagents.trainers.models import EncoderType
+ from mlagents.trainers.models import EncoderType, LearningRateSchedule
from mlagents.trainers.sac.models import SACModel
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.components.reward_signals.reward_signal_factory import (

            self.model = SACModel(
                brain,
                lr=float(trainer_params["learning_rate"]),
+               lr_schedule=LearningRateSchedule(
+                   trainer_params.get("learning_rate_schedule", "constant")
+               ),
                h_size=int(trainer_params["hidden_units"]),
                init_entcoef=float(trainer_params["init_entcoef"]),
                max_step=float(trainer_params["max_steps"]),

ml-agents/mlagents/trainers/trainer.py (2 lines changed)


from mlagents.envs import UnityException, AllBrainInfo, ActionInfoOutputs, BrainInfo
from mlagents.envs.timers import set_gauge
- from mlagents.trainers import TrainerMetrics
+ from mlagents.trainers.trainer_metrics import TrainerMetrics
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.tf_policy import Policy
from mlagents.envs import BrainParameters
