
Refactor reward signals into separate class (#2144)

* Create new class (RewardSignal) that represents a reward signal. 
* Add value heads for each reward signal in the PPO model.
* Make summaries agnostic to the type of reward signals, and log weighted rewards per reward signal. 
* Move extrinsic and curiosity rewards into this new structure.
* Allow defining multiple reward signals in YAML file. Add documentation for this new structure.
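A minimal sketch of what the new configuration looks like once parsed, based on the keys that appear in the trainer_config.yaml and ppo/policy.py changes below (values are illustrative; exact defaults may differ):

# Sketch only: the reward_signals YAML section is parsed into a dict like this.
# create_reward_signal() then builds one RewardSignal per entry, and the PPO
# model creates one value head per key (passed in as stream_names).
trainer_params = {
    "reward_signals": {
        "extrinsic": {"strength": 1.0, "gamma": 0.99},
        "curiosity": {"strength": 0.02, "gamma": 0.99, "encoding_size": 256},
    },
    # ...remaining PPO hyperparameters unchanged
}
stream_names = list(trainer_params["reward_signals"].keys())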
/develop-generalizationTraining-TrainerController
GitHub · 6 years ago
Current commit
4ac79742
27 changed files with 1,302 additions and 636 deletions
  1. config/trainer_config.yaml (56 changes)
  2. docs/Training-PPO.md (45 changes)
  3. ml-agents/mlagents/trainers/bc/policy.py (2 changes)
  4. ml-agents/mlagents/trainers/demo_loader.py (65 changes)
  5. ml-agents/mlagents/trainers/models.py (122 changes)
  6. ml-agents/mlagents/trainers/policy.py (2 changes)
  7. ml-agents/mlagents/trainers/ppo/models.py (222 changes)
  8. ml-agents/mlagents/trainers/ppo/policy.py (138 changes)
  9. ml-agents/mlagents/trainers/ppo/trainer.py (231 changes)
  10. ml-agents/mlagents/trainers/tests/test_ppo.py (158 changes)
  11. ml-agents/mlagents/trainers/trainer.py (22 changes)
  12. docs/Training-RewardSignals.md (114 changes)
  13. ml-agents/mlagents/trainers/tests/mock_brain.py (92 changes)
  14. ml-agents/mlagents/trainers/tests/test_reward_signals.py (156 changes)
  15. ml-agents/mlagents/trainers/components/__init__.py (0 changes)
  16. ml-agents/mlagents/trainers/components/reward_signals/__init__.py (1 change)
  17. ml-agents/mlagents/trainers/components/reward_signals/curiosity/__init__.py (1 change)
  18. ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (179 changes)
  19. ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (178 changes)
  20. ml-agents/mlagents/trainers/components/reward_signals/extrinsic/__init__.py (1 change)
  21. ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py (42 changes)
  22. ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py (68 changes)
  23. ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py (43 changes)

config/trainer_config.yaml (56 changes)


beta: 5.0e-3
buffer_size: 10240
epsilon: 0.2
gamma: 0.99
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4

sequence_length: 64
summary_freq: 1000
use_recurrent: false
use_curiosity: false
curiosity_strength: 0.01
curiosity_enc_size: 128
reward_signals:
    extrinsic:
        strength: 1.0
        gamma: 0.99
batch_size: 1024
batch_size: 1024
buffer_size: 10240
max_steps: 1.0e5

normalize: false
PyramidsLearning:
use_curiosity: true
curiosity_strength: 0.01
curiosity_enc_size: 256
time_horizon: 128
batch_size: 128
buffer_size: 2048

max_steps: 5.0e5
num_epoch: 3
reward_signals:
    extrinsic:
        strength: 1.0
        gamma: 0.99
    curiosity:
        strength: 0.02
        gamma: 0.99
        encoding_size: 256
use_curiosity: true
curiosity_strength: 0.01
curiosity_enc_size: 256
time_horizon: 128
batch_size: 64
buffer_size: 2024

max_steps: 5.0e5
num_epoch: 3
reward_signals:
    extrinsic:
        strength: 1.0
        gamma: 0.99
    curiosity:
        strength: 0.01
        gamma: 0.99
        encoding_size: 256
3DBallLearning:
normalize: true

time_horizon: 1000
lambd: 0.99
gamma: 0.995
beta: 0.001
3DBallHardLearning:

summary_freq: 1000
time_horizon: 1000
max_steps: 5.0e5
gamma: 0.995
reward_signals:
    extrinsic:
        strength: 1.0
        gamma: 0.995
TennisLearning:
normalize: true

time_horizon: 1000
batch_size: 2024
buffer_size: 20240
gamma: 0.995
reward_signals:
    extrinsic:
        strength: 1.0
        gamma: 0.995
CrawlerDynamicLearning:
normalize: true

buffer_size: 20240
gamma: 0.995
reward_signals:
    extrinsic:
        strength: 1.0
        gamma: 0.995
WalkerLearning:
normalize: true

buffer_size: 20480
gamma: 0.995
reward_signals:
    extrinsic:
        strength: 1.0
        gamma: 0.995
ReacherLearning:
normalize: true

hidden_units: 128
memory_size: 256
beta: 1.0e-2
gamma: 0.99
num_epoch: 3
buffer_size: 1024
batch_size: 128

docs/Training-PPO.md (45 changes)


The ML-Agents PPO algorithm is implemented in TensorFlow and runs in a separate
Python process (communicating with the running Unity application over a socket).
To train an agent, you will need to provide it with one or more reward signals that
it should attempt to maximize. See [Reward Signals](Training-RewardSignals.md)
for the available reward signals and the corresponding hyperparameters.
See [Training ML-Agents](Training-ML-Agents.md) for instructions on running the
training program, `learn.py`.

## Hyperparameters
### Gamma

`gamma` corresponds to the discount factor for future rewards. This can be
thought of as how far into the future the agent should care about possible
rewards. In situations when the agent should be acting in the present in order
to prepare for rewards in the distant future, this value should be large. In
cases when rewards are more immediate, it can be smaller.

Typical Range: `0.8` - `0.995`

### Reward Signals

In reinforcement learning, the goal is to learn a Policy that maximizes reward.
At a base level, the reward is given by the environment. However, we could imagine
rewarding the agent for various different behaviors. For instance, we could reward
the agent for exploring new states, rather than just when an explicit reward is given.
Furthermore, we could mix reward signals to help the learning process.

`reward_signals` provides a section to define [reward signals](Training-RewardSignals.md).
ML-Agents provides two reward signals by default: the Extrinsic (environment) reward and the
Curiosity reward, which can be used to encourage exploration in sparse extrinsic reward
environments.
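As a rough illustration of how each signal's `strength` weights its raw reward (a simplified sketch with a hypothetical helper, not the exact training math, which keeps a separate value head and return per signal):

def scale_rewards(raw_rewards, strengths):
    """Scale each reward signal's raw reward by its configured strength."""
    return {name: strengths[name] * reward for name, reward in raw_rewards.items()}

print(scale_rewards({"extrinsic": 1.0, "curiosity": 0.5},
                    {"extrinsic": 1.0, "curiosity": 0.02}))
# {'extrinsic': 1.0, 'curiosity': 0.01}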
### Lambda

the agent will need to remember in order to successfully complete the task.
Typical Range: `64` - `512`
## (Optional) Intrinsic Curiosity Module Hyperparameters
The below hyperparameters are only used when `use_curiosity` is set to true.
### Curiosity Encoding Size
`curiosity_enc_size` corresponds to the size of the hidden layer used to encode
the observations within the intrinsic curiosity module. This value should be
small enough to encourage the curiosity module to compress the original
observation, but not so small that it cannot learn the dynamics of
the environment.
Typical Range: `64` - `256`
### Curiosity Strength
`curiosity_strength` corresponds to the magnitude of the intrinsic reward
generated by the intrinsic curiosity module. It should be scaled so that the
intrinsic reward is not overwhelmed by the extrinsic reward signal from the
environment, and likewise not so large that it overwhelms the extrinsic
reward signal.

Typical Range: `0.001` - `0.1`
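For intuition, in the forward model shown in the ppo/models.py diff below (`create_forward_model`), the strength simply scales a clipped prediction error; a minimal numpy sketch with illustrative values:

import numpy as np

# Mirrors the intrinsic reward computation in create_forward_model:
# 0.5 * squared prediction error, scaled by curiosity_strength, clipped to [0, 1].
curiosity_strength = 0.01
pred_next_state = np.array([[0.2, 0.4], [0.1, 0.9]])     # forward-model prediction
encoded_next_state = np.array([[0.0, 0.5], [0.1, 0.1]])  # actual encoded next state
prediction_error = 0.5 * np.sum((pred_next_state - encoded_next_state) ** 2, axis=1)
intrinsic_reward = np.clip(curiosity_strength * prediction_error, 0.0, 1.0)
# more surprising transitions -> larger error -> larger intrinsic reward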
## Training Statistics

ml-agents/mlagents/trainers/bc/policy.py (2 changes)


self.model.sequence_length: 1,
}
feed_dict = self._fill_eval_dict(feed_dict, brain_info)
feed_dict = self.fill_eval_dict(feed_dict, brain_info)
if self.use_recurrent:
if brain_info.memories.shape[1] == 0:
brain_info.memories = self.make_empty_memory(len(brain_info.agents))

ml-agents/mlagents/trainers/demo_loader.py (65 changes)


current_brain_info.vector_observations[0]
)
demo_buffer[0]["actions"].append(next_brain_info.previous_vector_actions[0])
demo_buffer[0]["prev_action"].append(
current_brain_info.previous_vector_actions[0]
)
if next_brain_info.local_done[0]:
demo_buffer.append_update_buffer(
0, batch_size=None, training_length=sequence_length

# First 32 bytes of file dedicated to meta-data.
INITIAL_POS = 33
if not os.path.isfile(file_path):
file_paths = []
if os.path.isdir(file_path):
all_files = os.listdir(file_path)
for _file in all_files:
if _file.endswith(".demo"):
file_paths.append(_file)
elif os.path.isfile(file_path):
file_paths.append(file_path)
else:
"The demonstration file {} does not exist.".format(file_path)
"The demonstration file or directory {} does not exist.".format(file_path)
)
file_extension = pathlib.Path(file_path).suffix
if file_extension != ".demo":

brain_params = None
brain_infos = []
data = open(file_path, "rb").read()
next_pos, pos, obs_decoded = 0, 0, 0
total_expected = 0
while pos < len(data):
next_pos, pos = _DecodeVarint32(data, pos)
if obs_decoded == 0:
meta_data_proto = DemonstrationMetaProto()
meta_data_proto.ParseFromString(data[pos : pos + next_pos])
total_expected = meta_data_proto.number_steps
pos = INITIAL_POS
if obs_decoded == 1:
brain_param_proto = BrainParametersProto()
brain_param_proto.ParseFromString(data[pos : pos + next_pos])
brain_params = BrainParameters.from_proto(brain_param_proto)
pos += next_pos
if obs_decoded > 1:
agent_info = AgentInfoProto()
agent_info.ParseFromString(data[pos : pos + next_pos])
brain_info = BrainInfo.from_agent_proto([agent_info], brain_params)
brain_infos.append(brain_info)
if len(brain_infos) == total_expected:
break
pos += next_pos
obs_decoded += 1
for _file_path in file_paths:
data = open(_file_path, "rb").read()
next_pos, pos, obs_decoded = 0, 0, 0
total_expected = 0
while pos < len(data):
next_pos, pos = _DecodeVarint32(data, pos)
if obs_decoded == 0:
meta_data_proto = DemonstrationMetaProto()
meta_data_proto.ParseFromString(data[pos : pos + next_pos])
total_expected = meta_data_proto.number_steps
pos = INITIAL_POS
if obs_decoded == 1:
brain_param_proto = BrainParametersProto()
brain_param_proto.ParseFromString(data[pos : pos + next_pos])
brain_params = BrainParameters.from_proto(brain_param_proto)
pos += next_pos
if obs_decoded > 1:
agent_info = AgentInfoProto()
agent_info.ParseFromString(data[pos : pos + next_pos])
brain_info = BrainInfo.from_agent_proto([agent_info], brain_params)
brain_infos.append(brain_info)
if len(brain_infos) == total_expected:
break
pos += next_pos
obs_decoded += 1
return brain_params, brain_infos, total_expected

ml-agents/mlagents/trainers/models.py (122 changes)


class LearningModel(object):
_version_number_ = 2
def __init__(self, m_size, normalize, use_recurrent, brain, seed):
def __init__(
self, m_size, normalize, use_recurrent, brain, seed, stream_names=None
):
tf.set_random_seed(seed)
self.brain = brain
self.vector_in = None

)
self.mask_input = tf.placeholder(shape=[None], dtype=tf.float32, name="masks")
self.mask = tf.cast(self.mask_input, tf.int32)
self.stream_names = stream_names or []
self.use_recurrent = use_recurrent
if self.use_recurrent:
self.m_size = m_size

return global_step, increment_step
@staticmethod
def scaled_init(scale):
return c_layers.variance_scaling_initializer(scale)
@staticmethod
def swish(input_activation):
"""Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
return tf.multiply(input_activation, tf.nn.sigmoid(input_activation))

shape=[None, self.vec_obs_size], dtype=tf.float32, name=name
)
if self.normalize:
self.running_mean = tf.get_variable(
"running_mean",
[self.vec_obs_size],
trainable=False,
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
self.running_variance = tf.get_variable(
"running_variance",
[self.vec_obs_size],
trainable=False,
dtype=tf.float32,
initializer=tf.ones_initializer(),
)
self.update_mean, self.update_variance = self.create_normalizer_update(
self.vector_in
)
self.normalized_state = tf.clip_by_value(
(self.vector_in - self.running_mean)
/ tf.sqrt(
self.running_variance / (tf.cast(self.global_step, tf.float32) + 1)
),
-5,
5,
name="normalized_state",
)
return self.normalized_state
self.create_normalizer(self.vector_in)
return self.normalize_vector_obs(self.vector_in)
def normalize_vector_obs(self, vector_obs):
normalized_state = tf.clip_by_value(
(vector_obs - self.running_mean)
/ tf.sqrt(
self.running_variance
/ (tf.cast(self.normalization_steps, tf.float32) + 1)
),
-5,
5,
name="normalized_state",
)
return normalized_state
def create_normalizer(self, vector_obs):
self.normalization_steps = tf.get_variable(
"normalization_steps",
[],
trainable=False,
dtype=tf.int32,
initializer=tf.ones_initializer(),
)
self.running_mean = tf.get_variable(
"running_mean",
[self.vec_obs_size],
trainable=False,
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
self.running_variance = tf.get_variable(
"running_variance",
[self.vec_obs_size],
trainable=False,
dtype=tf.float32,
initializer=tf.ones_initializer(),
)
self.update_normalization = self.create_normalizer_update(vector_obs)
) / tf.cast(tf.add(self.global_step, 1), tf.float32)
) / tf.cast(tf.add(self.normalization_steps, 1), tf.float32)
return update_mean, update_variance
update_norm_step = tf.assign(
self.normalization_steps, self.normalization_steps + 1
)
return tf.group([update_mean, update_variance, update_norm_step])
@staticmethod
def create_vector_observation_encoder(

m_size = memory_in.get_shape().as_list()[1]
lstm_input_state = tf.reshape(input_state, shape=[-1, sequence_length, s_size])
memory_in = tf.reshape(memory_in[:, :], [-1, m_size])
_half_point = int(m_size / 2)
half_point = int(m_size / 2)
rnn_cell = tf.contrib.rnn.BasicLSTMCell(_half_point)
rnn_cell = tf.contrib.rnn.BasicLSTMCell(half_point)
memory_in[:, :_half_point], memory_in[:, _half_point:]
memory_in[:, :half_point], memory_in[:, half_point:]
recurrent_output = tf.reshape(recurrent_output, shape=[-1, _half_point])
recurrent_output = tf.reshape(recurrent_output, shape=[-1, half_point])
def create_value_heads(self, stream_names, hidden_input):
"""
Creates one value estimator head for each reward signal in stream_names.
Also creates the node corresponding to the mean of all the value heads in self.value.
self.value_heads is a dictionary of stream name to node containing the value estimator head for that signal.
:param stream_names: The list of reward signal names
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top
of the hidden input.
"""
self.value_heads = {}
for name in stream_names:
value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name))
self.value_heads[name] = value
self.value = tf.reduce_mean(list(self.value_heads.values()), 0)
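A minimal standalone sketch of the same value-head pattern (assuming TF 1.x, as used in this repository; the stream names are illustrative):

import tensorflow as tf

# One 1-unit dense head per reward signal, plus their mean as the single
# value estimate the rest of the policy consumes.
hidden_value = tf.placeholder(tf.float32, shape=[None, 128], name="hidden_value")
value_heads = {
    name: tf.layers.dense(hidden_value, 1, name="{}_value".format(name))
    for name in ["extrinsic", "curiosity"]
}
value = tf.reduce_mean(list(value_heads.values()), 0)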
def create_cc_actor_critic(self, h_size, num_layers):
"""
Creates Continuous control actor-critic model.

kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01),
)
log_sigma_sq = tf.get_variable(
self.log_sigma_sq = tf.get_variable(
"log_sigma_squared",
[self.act_size[0]],
dtype=tf.float32,

sigma_sq = tf.exp(log_sigma_sq)
sigma_sq = tf.exp(self.log_sigma_sq)
self.epsilon = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon"

all_probs = (
-0.5 * tf.square(tf.stop_gradient(self.output_pre) - mu) / sigma_sq
- 0.5 * tf.log(2.0 * np.pi)
- 0.5 * log_sigma_sq
- 0.5 * self.log_sigma_sq
self.entropy = 0.5 * tf.reduce_mean(tf.log(2 * np.pi * np.e) + log_sigma_sq)
self.entropy = 0.5 * tf.reduce_mean(
tf.log(2 * np.pi * np.e) + self.log_sigma_sq
)
value = tf.layers.dense(hidden_value, 1, activation=None)
self.value = tf.identity(value, name="value_estimate")
self.create_value_heads(self.stream_names, hidden_value)
self.all_old_log_probs = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="old_probabilities"

self.output = tf.identity(output)
self.normalized_logits = tf.identity(normalized_logits, name="action")
value = tf.layers.dense(hidden, 1, activation=None)
self.value = tf.identity(value, name="value_estimate")
self.create_value_heads(self.stream_names, hidden)
self.action_holder = tf.placeholder(
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder"

ml-agents/mlagents/trainers/policy.py (2 changes)


run_out = dict(zip(list(out_dict.keys()), network_out))
return run_out
def _fill_eval_dict(self, feed_dict, brain_info):
def fill_eval_dict(self, feed_dict, brain_info):
for i, _ in enumerate(brain_info.visual_observations):
feed_dict[self.model.visual_in[i]] = brain_info.visual_observations[i]
if self.use_vec_obs:

ml-agents/mlagents/trainers/ppo/models.py (222 changes)


use_recurrent=False,
num_layers=2,
m_size=None,
use_curiosity=False,
curiosity_strength=0.01,
curiosity_enc_size=128,
stream_names=None,
):
"""
Takes a Unity environment and model-specific hyper-parameters and returns the

:param h_size: Size of hidden layers
:param epsilon: Value for policy-divergence threshold.
:param beta: Strength of entropy regularization.
:return: a sub-class of PPOAgent tailored to the environment.
:param seed: Seed to use for initialization of model.
:param stream_names: List of names of value streams. Usually, a list of the Reward Signals being used.
:return: a sub-class of PPOAgent tailored to the environment.
LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed)
self.use_curiosity = use_curiosity
LearningModel.__init__(
self, m_size, normalize, use_recurrent, brain, seed, stream_names
)
if num_layers < 1:
num_layers = 1
self.last_reward, self.new_reward, self.update_reward = (

self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy
else:
self.create_dc_actor_critic(h_size, num_layers)
if self.use_curiosity:
self.curiosity_enc_size = curiosity_enc_size
self.curiosity_strength = curiosity_strength
encoded_state, encoded_next_state = self.create_curiosity_encoders()
self.create_inverse_model(encoded_state, encoded_next_state)
self.create_forward_model(encoded_state, encoded_next_state)
self.create_ppo_optimizer(
self.create_losses(
self.value,
self.value_heads,
self.entropy,
beta,
epsilon,

update_reward = tf.assign(last_reward, new_reward)
return last_reward, new_reward, update_reward
def create_curiosity_encoders(self):
"""
Creates state encoders for current and future observations.
Used for implementation of Curiosity-driven Exploration by Self-supervised Prediction
See https://arxiv.org/abs/1705.05363 for more details.
:return: current and future state encoder tensors.
"""
encoded_state_list = []
encoded_next_state_list = []
if self.vis_obs_size > 0:
self.next_visual_in = []
visual_encoders = []
next_visual_encoders = []
for i in range(self.vis_obs_size):
# Create input ops for next (t+1) visual observations.
next_visual_input = self.create_visual_input(
self.brain.camera_resolutions[i],
name="next_visual_observation_" + str(i),
)
self.next_visual_in.append(next_visual_input)
# Create the encoder ops for current and next visual input. Note that these encoders are siamese.
encoded_visual = self.create_visual_observation_encoder(
self.visual_in[i],
self.curiosity_enc_size,
self.swish,
1,
"stream_{}_visual_obs_encoder".format(i),
False,
)
encoded_next_visual = self.create_visual_observation_encoder(
self.next_visual_in[i],
self.curiosity_enc_size,
self.swish,
1,
"stream_{}_visual_obs_encoder".format(i),
True,
)
visual_encoders.append(encoded_visual)
next_visual_encoders.append(encoded_next_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
hidden_next_visual = tf.concat(next_visual_encoders, axis=1)
encoded_state_list.append(hidden_visual)
encoded_next_state_list.append(hidden_next_visual)
if self.vec_obs_size > 0:
# Create the encoder ops for current and next vector input. Note that these encoders are siamese.
# Create input op for next (t+1) vector observation.
self.next_vector_in = tf.placeholder(
shape=[None, self.vec_obs_size],
dtype=tf.float32,
name="next_vector_observation",
)
encoded_vector_obs = self.create_vector_observation_encoder(
self.vector_in,
self.curiosity_enc_size,
self.swish,
2,
"vector_obs_encoder",
False,
)
encoded_next_vector_obs = self.create_vector_observation_encoder(
self.next_vector_in,
self.curiosity_enc_size,
self.swish,
2,
"vector_obs_encoder",
True,
)
encoded_state_list.append(encoded_vector_obs)
encoded_next_state_list.append(encoded_next_vector_obs)
encoded_state = tf.concat(encoded_state_list, axis=1)
encoded_next_state = tf.concat(encoded_next_state_list, axis=1)
return encoded_state, encoded_next_state
def create_inverse_model(self, encoded_state, encoded_next_state):
"""
Creates inverse model TensorFlow ops for Curiosity module.
Predicts action taken given current and future encoded states.
:param encoded_state: Tensor corresponding to encoded current state.
:param encoded_next_state: Tensor corresponding to encoded next state.
"""
combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
hidden = tf.layers.dense(combined_input, 256, activation=self.swish)
if self.brain.vector_action_space_type == "continuous":
pred_action = tf.layers.dense(hidden, self.act_size[0], activation=None)
squared_difference = tf.reduce_sum(
tf.squared_difference(pred_action, self.selected_actions), axis=1
)
self.inverse_loss = tf.reduce_mean(
tf.dynamic_partition(squared_difference, self.mask, 2)[1]
)
else:
pred_action = tf.concat(
[
tf.layers.dense(hidden, self.act_size[i], activation=tf.nn.softmax)
for i in range(len(self.act_size))
],
axis=1,
)
cross_entropy = tf.reduce_sum(
-tf.log(pred_action + 1e-10) * self.selected_actions, axis=1
)
self.inverse_loss = tf.reduce_mean(
tf.dynamic_partition(cross_entropy, self.mask, 2)[1]
)
def create_forward_model(self, encoded_state, encoded_next_state):
"""
Creates forward model TensorFlow ops for Curiosity module.
Predicts encoded future state based on encoded current state and given action.
:param encoded_state: Tensor corresponding to encoded current state.
:param encoded_next_state: Tensor corresponding to encoded next state.
"""
combined_input = tf.concat([encoded_state, self.selected_actions], axis=1)
hidden = tf.layers.dense(combined_input, 256, activation=self.swish)
# We compare against the concatenation of all observation streams, hence `self.vis_obs_size + int(self.vec_obs_size > 0)`.
pred_next_state = tf.layers.dense(
hidden,
self.curiosity_enc_size * (self.vis_obs_size + int(self.vec_obs_size > 0)),
activation=None,
)
squared_difference = 0.5 * tf.reduce_sum(
tf.squared_difference(pred_next_state, encoded_next_state), axis=1
)
self.intrinsic_reward = tf.clip_by_value(
self.curiosity_strength * squared_difference, 0, 1
)
self.forward_loss = tf.reduce_mean(
tf.dynamic_partition(squared_difference, self.mask, 2)[1]
)
def create_ppo_optimizer(
self, probs, old_probs, value, entropy, beta, epsilon, lr, max_step
def create_losses(
self, probs, old_probs, value_heads, entropy, beta, epsilon, lr, max_step
:param value: Current value estimate
:param value_heads: Value estimate tensors from each value stream
:param beta: Entropy regularization strength
:param entropy: Current policy entropy
:param epsilon: Value for policy-divergence threshold

self.returns_holder = tf.placeholder(
shape=[None], dtype=tf.float32, name="discounted_rewards"
)
self.returns_holders = {}
self.old_values = {}
for name in value_heads.keys():
returns_holder = tf.placeholder(
shape=[None], dtype=tf.float32, name="{}_returns".format(name)
)
old_value = tf.placeholder(
shape=[None], dtype=tf.float32, name="{}_value_estimate".format(name)
)
self.returns_holders[name] = returns_holder
self.old_values[name] = old_value
self.advantage = tf.placeholder(
shape=[None, 1], dtype=tf.float32, name="advantages"
)

self.old_value = tf.placeholder(
shape=[None], dtype=tf.float32, name="old_value_estimates"
)
decay_epsilon = tf.train.polynomial_decay(
epsilon, self.global_step, max_step, 0.1, power=1.0
)

optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
clipped_value_estimate = self.old_value + tf.clip_by_value(
tf.reduce_sum(value, axis=1) - self.old_value, -decay_epsilon, decay_epsilon
)
value_losses = []
for name, head in value_heads.items():
clipped_value_estimate = self.old_values[name] + tf.clip_by_value(
tf.reduce_sum(head, axis=1) - self.old_values[name],
-decay_epsilon,
decay_epsilon,
)
v_opt_a = tf.squared_difference(
self.returns_holders[name], tf.reduce_sum(head, axis=1)
)
v_opt_b = tf.squared_difference(
self.returns_holders[name], clipped_value_estimate
)
value_loss = tf.reduce_mean(
tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.mask, 2)[1]
)
value_losses.append(value_loss)
self.value_loss = tf.reduce_mean(value_losses)
v_opt_a = tf.squared_difference(
self.returns_holder, tf.reduce_sum(value, axis=1)
)
v_opt_b = tf.squared_difference(self.returns_holder, clipped_value_estimate)
self.value_loss = tf.reduce_mean(
tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.mask, 2)[1]
)
# Here we calculate PPO policy loss. In continuous control this is done independently for each action gaussian
# and then averaged together. This provides significantly better performance than treating the probability
# as an average of probabilities, or as a joint probability.
r_theta = tf.exp(probs - old_probs)
p_opt_a = r_theta * self.advantage
p_opt_b = (

* tf.reduce_mean(tf.dynamic_partition(entropy, self.mask, 2)[1])
)
if self.use_curiosity:
self.loss += 10 * (0.2 * self.forward_loss + 0.8 * self.inverse_loss)
def create_ppo_optimizer(self):
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
self.update_batch = optimizer.minimize(self.loss)

ml-agents/mlagents/trainers/ppo/policy.py (138 changes)


import logging
import numpy as np
from mlagents.trainers import BrainInfo, ActionInfo
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
create_reward_signal,
)
logger = logging.getLogger("mlagents.trainers")

:param load: Whether a pre-trained model will be loaded or a new one created.
"""
super().__init__(seed, brain, trainer_params)
self.has_updated = False
self.use_curiosity = bool(trainer_params["use_curiosity"])
reward_signal_configs = trainer_params["reward_signals"]
self.reward_signals = {}
with self.graph.as_default():
self.model = PPOModel(
brain,

use_recurrent=trainer_params["use_recurrent"],
num_layers=int(trainer_params["num_layers"]),
m_size=self.m_size,
use_curiosity=bool(trainer_params["use_curiosity"]),
curiosity_strength=float(trainer_params["curiosity_strength"]),
curiosity_enc_size=float(trainer_params["curiosity_enc_size"]),
stream_names=list(reward_signal_configs.keys()),
self.model.create_ppo_optimizer()
# Create reward signals
for reward_signal, config in reward_signal_configs.items():
self.reward_signals[reward_signal] = create_reward_signal(
self, reward_signal, config
)
if load:
self._load_graph()

self.inference_dict = {
"action": self.model.output,
"log_probs": self.model.all_log_probs,
"value": self.model.value,
"value": self.model.value_heads,
"entropy": self.model.entropy,
"learning_rate": self.model.learning_rate,
}

self.inference_dict["memory_out"] = self.model.memory_out
if is_training and self.use_vec_obs and trainer_params["normalize"]:
self.inference_dict["update_mean"] = self.model.update_mean
self.inference_dict["update_variance"] = self.model.update_variance
if (
is_training
and self.use_vec_obs
and trainer_params["normalize"]
and not load
):
self.inference_dict["update_mean"] = self.model.update_normalization
self.total_policy_loss = self.model.policy_loss
"policy_loss": self.model.policy_loss,
"policy_loss": self.total_policy_loss,
if self.use_curiosity:
self.update_dict["forward_loss"] = self.model.forward_loss
self.update_dict["inverse_loss"] = self.model.inverse_loss
def evaluate(self, brain_info):
"""

size=(len(brain_info.vector_observations), self.model.act_size[0])
)
feed_dict[self.model.epsilon] = epsilon
feed_dict = self._fill_eval_dict(feed_dict, brain_info)
feed_dict = self.fill_eval_dict(feed_dict, brain_info)
run_out = self._execute_model(feed_dict, self.inference_dict)
if self.use_continuous_act:
run_out["random_normal_epsilon"] = epsilon

self.model.batch_size: num_sequences,
self.model.sequence_length: self.sequence_length,
self.model.mask_input: mini_batch["masks"].flatten(),
self.model.returns_holder: mini_batch["discounted_returns"].flatten(),
self.model.old_value: mini_batch["value_estimates"].flatten(),
for name in self.reward_signals:
feed_dict[self.model.returns_holders[name]] = mini_batch[
"{}_returns".format(name)
].flatten()
feed_dict[self.model.old_values[name]] = mini_batch[
"{}_value_estimates".format(name)
].flatten()
if self.use_continuous_act:
feed_dict[self.model.output_pre] = mini_batch["actions_pre"].reshape(
[-1, self.model.act_size[0]]

feed_dict[self.model.vector_in] = mini_batch["vector_obs"].reshape(
[-1, self.vec_obs_size]
)
if self.use_curiosity:
feed_dict[self.model.next_vector_in] = mini_batch[
"next_vector_in"
].reshape([-1, self.vec_obs_size])
if self.model.vis_obs_size > 0:
for i, _ in enumerate(self.model.visual_in):
_obs = mini_batch["visual_obs%d" % i]

else:
feed_dict[self.model.visual_in[i]] = _obs
if self.use_curiosity:
for i, _ in enumerate(self.model.visual_in):
_obs = mini_batch["next_visual_obs%d" % i]
if self.sequence_length > 1 and self.use_recurrent:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.next_visual_in[i]] = _obs.reshape(
[-1, _w, _h, _c]
)
else:
feed_dict[self.model.next_visual_in[i]] = _obs
self.has_updated = True
def get_intrinsic_rewards(self, curr_info, next_info):
"""
Generates intrinsic reward used for Curiosity-based training.
:BrainInfo curr_info: Current BrainInfo.
:BrainInfo next_info: Next BrainInfo.
:return: Intrinsic rewards for all agents.
"""
if self.use_curiosity:
if len(curr_info.agents) == 0:
return []
feed_dict = {
self.model.batch_size: len(next_info.vector_observations),
self.model.sequence_length: 1,
}
if self.use_continuous_act:
feed_dict[
self.model.selected_actions
] = next_info.previous_vector_actions
else:
feed_dict[self.model.action_holder] = next_info.previous_vector_actions
for i in range(self.model.vis_obs_size):
feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[
i
]
if self.use_vec_obs:
feed_dict[self.model.vector_in] = curr_info.vector_observations
feed_dict[self.model.next_vector_in] = next_info.vector_observations
if self.use_recurrent:
if curr_info.memories.shape[1] == 0:
curr_info.memories = self.make_empty_memory(len(curr_info.agents))
feed_dict[self.model.memory_in] = curr_info.memories
intrinsic_rewards = self.sess.run(
self.model.intrinsic_reward, feed_dict=feed_dict
) * float(self.has_updated)
return intrinsic_rewards
else:
return None
def get_value_estimate(self, brain_info, idx):
def get_value_estimates(self, brain_info, idx):
:return: Value estimate.
:return: The value estimate dictionary with key being the name of the reward signal and the value the
corresponding value estimate.
"""
feed_dict = {self.model.batch_size: 1, self.model.sequence_length: 1}
for i in range(len(brain_info.visual_observations)):

feed_dict[self.model.prev_action] = brain_info.previous_vector_actions[
idx
].reshape([-1, len(self.model.act_size)])
value_estimate = self.sess.run(self.model.value, feed_dict)
return value_estimate
value_estimates = self.sess.run(self.model.value_heads, feed_dict)
return value_estimates
def get_action(self, brain_info: BrainInfo) -> ActionInfo:
"""
Decides actions given observation information, and takes them in the environment.
:param brain_info: A dictionary of brain names and BrainInfo from environment.
:return: an ActionInfo containing action, memories, values and an object
to be passed to add experiences
"""
if len(brain_info.agents) == 0:
return ActionInfo([], [], [], None, None)
run_out = self.evaluate(brain_info)
mean_values = np.mean(
np.array(list(run_out.get("value").values())), axis=0
).flatten()
return ActionInfo(
action=run_out.get("action"),
memory=run_out.get("memory_out"),
text=None,
value=mean_values,
outputs=run_out,
)
def get_last_reward(self):
"""

ml-agents/mlagents/trainers/ppo/trainer.py (231 changes)


# # Unity ML-Agents Toolkit
# ## ML-Agent Learning (PPO)
# Contains an implementation of PPO as described (https://arxiv.org/abs/1707.06347).
# Contains an implementation of PPO as described in: https://arxiv.org/abs/1707.06347
from collections import deque
from collections import deque, defaultdict
from typing import Any, List
import numpy as np

from mlagents.trainers.buffer import Buffer
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.trainers.action_info import ActionInfoOutputs
logger = logging.getLogger("mlagents.trainers")

"beta",
"buffer_size",
"epsilon",
"gamma",
"hidden_units",
"lambd",
"learning_rate",

"use_recurrent",
"summary_path",
"memory_size",
"use_curiosity",
"curiosity_strength",
"curiosity_enc_size",
"reward_signals",
self.check_param_keys()
self.check_param_keys()
self.use_curiosity = bool(trainer_parameters["use_curiosity"])
# Make sure we have at least one reward_signal
if not self.trainer_parameters["reward_signals"]:
raise UnityTrainerException(
"No reward signals were defined. At least one must be used with {}.".format(
self.__class__.__name__
)
)
stats = {
"Environment/Cumulative Reward": [],
"Environment/Episode Length": [],
"Policy/Value Estimate": [],
"Policy/Entropy": [],
"Losses/Value Loss": [],
"Losses/Policy Loss": [],
"Policy/Learning Rate": [],
}
if self.use_curiosity:
stats["Losses/Forward Loss"] = []
stats["Losses/Inverse Loss"] = []
stats["Policy/Curiosity Reward"] = []
self.intrinsic_rewards = {}
stats = defaultdict(list)
# collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
# used for reporting only. We always want to report the environment reward to Tensorboard, regardless
# of what reward signals are actually present.
self.collected_rewards = {"environment": {}}
for _reward_signal in self.policy.reward_signals.keys():
self.collected_rewards[_reward_signal] = {}
self.cumulative_rewards = {}
return """Hyperparameters for the PPO Trainer of brain {0}: \n{1}""".format(
return """Hyperparameters for the {0} of brain {1}: \n{2}""".format(
self.__class__.__name__,
"\n".join(
[
"\t{0}:\t{1}".format(x, self.trainer_parameters[x])
for x in self.param_keys
]
),
self.dict_to_str(self.trainer_parameters, 0),
)
@property

"""
Increment the step count of the trainer and update the last reward.
"""
if len(self.stats["Environment/Cumulative Reward"]) > 0:
if self.stats["Environment/Cumulative Reward"]:
mean_reward = np.mean(self.stats["Environment/Cumulative Reward"])
self.policy.update_reward(mean_reward)
self.policy.increment_step()

"""
Constructs a BrainInfo which contains the most recent previous experiences for all agents info
Constructs a BrainInfo which contains the most recent previous experiences for all agents
which correspond to the agents in a provided next_info.
:BrainInfo next_info: A t+1 BrainInfo.
:return: curr_info: Reconstructed BrainInfo to match agents of next_info.

"""
self.trainer_metrics.start_experience_collection_timer()
if take_action_outputs:
self.stats["Policy/Value Estimate"].append(
take_action_outputs["value"].mean()
)
for name, signal in self.policy.reward_signals.items():
self.stats[signal.value_name].append(
np.mean(take_action_outputs["value"][name])
)
curr_info = curr_all_info[self.brain_name]
next_info = next_all_info[self.brain_name]

else:
curr_to_use = curr_info
intrinsic_rewards = self.policy.get_intrinsic_rewards(curr_to_use, next_info)
tmp_rewards_dict = {}
for name, signal in self.policy.reward_signals.items():
tmp_rewards_dict[name] = signal.evaluate(curr_to_use, next_info)
for agent_id in next_info.agents:
stored_info = self.training_buffer[agent_id].last_brain_info

stored_info.action_masks[idx], padding_value=1
)
a_dist = stored_take_action_outputs["log_probs"]
# value is a dictionary from name of reward to value estimate of the value head
value = stored_take_action_outputs["value"]
self.training_buffer[agent_id]["actions"].append(actions[idx])
self.training_buffer[agent_id]["prev_action"].append(

if self.use_curiosity:
self.training_buffer[agent_id]["rewards"].append(
next_info.rewards[next_idx] + intrinsic_rewards[next_idx]
)
else:
self.training_buffer[agent_id]["rewards"].append(
next_info.rewards[next_idx]
)
self.training_buffer[agent_id]["action_probs"].append(a_dist[idx])
self.training_buffer[agent_id]["value_estimates"].append(
value[idx][0]
self.training_buffer[agent_id]["done"].append(
next_info.local_done[next_idx]
if agent_id not in self.cumulative_rewards:
self.cumulative_rewards[agent_id] = 0
self.cumulative_rewards[agent_id] += next_info.rewards[next_idx]
if self.use_curiosity:
if agent_id not in self.intrinsic_rewards:
self.intrinsic_rewards[agent_id] = 0
self.intrinsic_rewards[agent_id] += intrinsic_rewards[next_idx]
for name, reward_result in tmp_rewards_dict.items():
# 0 because we use the scaled reward to train the agent
self.training_buffer[agent_id][
"{}_rewards".format(name)
].append(reward_result.scaled_reward[next_idx])
self.training_buffer[agent_id][
"{}_value_estimates".format(name)
].append(value[name][idx][0])
self.training_buffer[agent_id]["action_probs"].append(a_dist[idx])
for name, rewards in self.collected_rewards.items():
if agent_id not in rewards:
rewards[agent_id] = 0
if name == "environment":
# Report the reward from the environment
rewards[agent_id] += np.array(next_info.rewards)[next_idx]
else:
# Report the reward signals
rewards[agent_id] += tmp_rewards_dict[name].scaled_reward[
next_idx
]
if not next_info.local_done[next_idx]:
if agent_id not in self.episode_steps:
self.episode_steps[agent_id] = 0

:param current_info: Dictionary of all current brains and corresponding BrainInfo.
:param new_info: Dictionary of all next brains and corresponding BrainInfo.
"""
self.trainer_metrics.start_experience_collection_timer()
info = new_info[self.brain_name]
for l in range(len(info.agents)):
agent_actions = self.training_buffer[info.agents[l]]["actions"]

) and len(agent_actions) > 0:
agent_id = info.agents[l]
if info.max_reached[l]:
bootstrapping_info = self.training_buffer[agent_id].last_brain_info
idx = bootstrapping_info.agents.index(agent_id)
else:
bootstrapping_info = info
idx = l
value_next = self.policy.get_value_estimates(bootstrapping_info, idx)
value_next = 0.0
else:
if info.max_reached[l]:
bootstrapping_info = self.training_buffer[
agent_id
].last_brain_info
idx = bootstrapping_info.agents.index(agent_id)
else:
bootstrapping_info = info
idx = l
value_next = self.policy.get_value_estimate(bootstrapping_info, idx)
value_next["extrinsic"] = 0.0
tmp_advantages = []
tmp_returns = []
for name in self.policy.reward_signals:
bootstrap_value = value_next[name]
self.training_buffer[agent_id]["advantages"].set(
get_gae(
rewards=self.training_buffer[agent_id]["rewards"].get_batch(),
value_estimates=self.training_buffer[agent_id][
"value_estimates"
].get_batch(),
value_next=value_next,
gamma=self.trainer_parameters["gamma"],
local_rewards = self.training_buffer[agent_id][
"{}_rewards".format(name)
].get_batch()
local_value_estimates = self.training_buffer[agent_id][
"{}_value_estimates".format(name)
].get_batch()
local_advantage = get_gae(
rewards=local_rewards,
value_estimates=local_value_estimates,
value_next=bootstrap_value,
gamma=self.policy.reward_signals[name].gamma,
)
self.training_buffer[agent_id]["discounted_returns"].set(
self.training_buffer[agent_id]["advantages"].get_batch()
+ self.training_buffer[agent_id]["value_estimates"].get_batch()
)
local_return = local_advantage + local_value_estimates
# This is later used as the target for the different value estimates
self.training_buffer[agent_id]["{}_returns".format(name)].set(
local_return
)
self.training_buffer[agent_id]["{}_advantage".format(name)].set(
local_advantage
)
tmp_advantages.append(local_advantage)
tmp_returns.append(local_return)
global_advantages = list(np.mean(np.array(tmp_advantages), axis=0))
global_returns = list(np.mean(np.array(tmp_returns), axis=0))
self.training_buffer[agent_id]["advantages"].set(global_advantages)
self.training_buffer[agent_id]["discounted_returns"].set(global_returns)
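To make the averaging above concrete, a small numpy sketch with illustrative numbers (the real code operates on the per-agent training buffer):

import numpy as np

# Two reward signals, three steps; each array is that signal's GAE / return.
tmp_advantages = [np.array([0.5, -0.2, 0.1]),   # e.g. extrinsic
                  np.array([0.1, 0.3, 0.0])]    # e.g. curiosity
tmp_returns = [np.array([1.0, 0.8, 0.6]),
               np.array([0.2, 0.1, 0.1])]
global_advantages = list(np.mean(np.array(tmp_advantages), axis=0))  # [0.3, 0.05, 0.05]
global_returns = list(np.mean(np.array(tmp_returns), axis=0))        # [0.6, 0.45, 0.35]
# The PPO update uses these averaged values, while each value head is still
# trained against its own "{name}_returns".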
self.training_buffer.append_update_buffer(
agent_id,

self.training_buffer[agent_id].reset_agent()
if info.local_done[l]:
self.cumulative_returns_since_policy_update.append(
self.cumulative_rewards.get(agent_id, 0)
)
self.stats["Environment/Cumulative Reward"].append(
self.cumulative_rewards.get(agent_id, 0)
)
self.reward_buffer.appendleft(
self.cumulative_rewards.get(agent_id, 0)
)
self.cumulative_rewards[agent_id] = 0
if self.use_curiosity:
self.stats["Policy/Curiosity Reward"].append(
self.intrinsic_rewards.get(agent_id, 0)
)
self.intrinsic_rewards[agent_id] = 0
self.trainer_metrics.end_experience_collection_timer()
for name, rewards in self.collected_rewards.items():
if name == "environment":
self.cumulative_returns_since_policy_update.append(
rewards.get(agent_id, 0)
)
self.stats["Environment/Cumulative Reward"].append(
rewards.get(agent_id, 0)
)
rewards[agent_id] = 0
self.reward_buffer.appendleft(rewards.get(agent_id, 0))
else:
self.stats[
self.policy.reward_signals[name].stat_name
].append(rewards.get(agent_id, 0))
rewards[agent_id] = 0
def end_episode(self):
"""

self.training_buffer.reset_local_buffers()
for agent_id in self.cumulative_rewards:
self.cumulative_rewards[agent_id] = 0
if self.use_curiosity:
for agent_id in self.intrinsic_rewards:
self.intrinsic_rewards[agent_id] = 0
for rewards in self.collected_rewards.values():
for agent_id in rewards:
rewards[agent_id] = 0
def is_ready_update(self):
"""

def update_policy(self):
"""
Uses demonstration_buffer to update the policy.
The reward signal generators must be updated in this method at their own pace.
"""
self.trainer_metrics.start_policy_update_timer(
number_experiences=len(self.training_buffer.update_buffer["actions"]),

)
value_total.append(run_out["value_loss"])
policy_total.append(np.abs(run_out["policy_loss"]))
if self.use_curiosity:
inverse_total.append(run_out["inverse_loss"])
forward_total.append(run_out["forward_loss"])
if self.use_curiosity:
self.stats["Losses/Forward Loss"].append(np.mean(forward_total))
self.stats["Losses/Inverse Loss"].append(np.mean(inverse_total))
for _, reward_signal in self.policy.reward_signals.items():
update_stats = reward_signal.update(
self.training_buffer.update_buffer, n_sequences
)
for stat, val in update_stats.items():
self.stats[stat].append(val)
self.training_buffer.reset_update_buffer()
self.trainer_metrics.end_policy_update()

ml-agents/mlagents/trainers/tests/test_ppo.py (158 changes)


beta: 5.0e-3
buffer_size: 512
epsilon: 0.2
gamma: 0.99
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4

summary_freq: 1000
use_recurrent: false
memory_size: 8
use_curiosity: false
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)

model.memory_in: np.zeros((1, memory_size)),
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.epsilon: np.array([[0, 1]]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_ppo_model_dc_vector_curio(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=True, visual_inputs=0
)
env = UnityEnvironment(" ")
model = PPOModel(env.brains["RealFakeBrain"], use_curiosity=True)
init = tf.global_variables_initializer()
sess.run(init)
run_list = [
model.output,
model.all_log_probs,
model.value,
model.entropy,
model.learning_rate,
model.intrinsic_reward,
]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.next_vector_in: np.array(
[[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]
),
model.action_holder: [[0], [0]],
model.action_masks: np.ones([2, 2]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_ppo_model_cc_vector_curio(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
model = PPOModel(env.brains["RealFakeBrain"], use_curiosity=True)
init = tf.global_variables_initializer()
sess.run(init)
run_list = [
model.output,
model.all_log_probs,
model.value,
model.entropy,
model.learning_rate,
model.intrinsic_reward,
]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,