
Refactor reward signals into separate class (#2144)

* Create new class (RewardSignal) that represents a reward signal. 
* Add value heads for each reward signal in the PPO model.
* Make summaries agnostic to the type of reward signals, and log weighted rewards per reward signal. 
* Move extrinsic and curiosity rewards into this new structure.
* Allow defining multiple reward signals in YAML file. Add documentation for this new structure.
/develop-generalizationTraining-TrainerController
GitHub, 5 years ago
Current commit: 4ac79742
27 files changed, with 1,302 insertions and 636 deletions
  1. config/trainer_config.yaml (56 lines changed)
  2. docs/Training-PPO.md (45 lines changed)
  3. ml-agents/mlagents/trainers/bc/policy.py (2 lines changed)
  4. ml-agents/mlagents/trainers/demo_loader.py (65 lines changed)
  5. ml-agents/mlagents/trainers/models.py (122 lines changed)
  6. ml-agents/mlagents/trainers/policy.py (2 lines changed)
  7. ml-agents/mlagents/trainers/ppo/models.py (222 lines changed)
  8. ml-agents/mlagents/trainers/ppo/policy.py (138 lines changed)
  9. ml-agents/mlagents/trainers/ppo/trainer.py (231 lines changed)
  10. ml-agents/mlagents/trainers/tests/test_ppo.py (158 lines changed)
  11. ml-agents/mlagents/trainers/trainer.py (22 lines changed)
  12. docs/Training-RewardSignals.md (114 lines changed)
  13. ml-agents/mlagents/trainers/tests/mock_brain.py (92 lines changed)
  14. ml-agents/mlagents/trainers/tests/test_reward_signals.py (156 lines changed)
  15. ml-agents/mlagents/trainers/components/__init__.py (0 lines changed)
  16. ml-agents/mlagents/trainers/components/reward_signals/__init__.py (1 line changed)
  17. ml-agents/mlagents/trainers/components/reward_signals/curiosity/__init__.py (1 line changed)
  18. ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (179 lines changed)
  19. ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (178 lines changed)
  20. ml-agents/mlagents/trainers/components/reward_signals/extrinsic/__init__.py (1 line changed)
  21. ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py (42 lines changed)
  22. ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py (68 lines changed)
  23. ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py (43 lines changed)
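The heart of the refactor is the new `RewardSignal` base class and the factory in `reward_signal_factory.py`. As a rough orientation before the per-file diffs, here is a minimal sketch, assembled only from how the new class is used elsewhere in this change (per-signal `strength` and `gamma`, `stat_name`/`value_name` for logging, `evaluate()` returning scaled and unscaled rewards, `update()` returning a dict of stats); the exact signatures and defaults in the real files may differ.

```python
from collections import namedtuple

import numpy as np

# evaluate() hands back both the strength-scaled reward (used for training)
# and the raw reward (used for reporting), matching how the tests unpack it.
RewardSignalResult = namedtuple(
    "RewardSignalResult", ["scaled_reward", "unscaled_reward"]
)


class RewardSignal:
    def __init__(self, policy, strength, gamma):
        self.policy = policy
        self.strength = strength
        self.gamma = gamma  # per-signal discount factor, used for GAE in the trainer
        name = self.__class__.__name__.replace("RewardSignal", "")
        self.stat_name = "Policy/{} Reward".format(name)
        self.value_name = "Policy/{} Value Estimate".format(name)

    def evaluate(self, current_info, next_info):
        """Return scaled and unscaled rewards, one entry per agent in next_info."""
        unscaled_reward = np.zeros(len(next_info.agents))
        return RewardSignalResult(self.strength * unscaled_reward, unscaled_reward)

    def update(self, update_buffer, num_sequences):
        """Update any trainable parts of the signal; return a dict of stats to log."""
        return {}
```

The factory's `create_reward_signal(policy, name, config)` then maps each key under `reward_signals` in the trainer config to one such object on `policy.reward_signals`, as the `ppo/policy.py` diff below shows.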

config/trainer_config.yaml (56 lines changed)


beta: 5.0e-3
buffer_size: 10240
epsilon: 0.2
gamma: 0.99
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4

sequence_length: 64
summary_freq: 1000
use_recurrent: false
use_curiosity: false
curiosity_strength: 0.01
curiosity_enc_size: 128
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
batch_size: 1024
batch_size: 1024
buffer_size: 10240
max_steps: 1.0e5

normalize: false
PyramidsLearning:
use_curiosity: true
curiosity_strength: 0.01
curiosity_enc_size: 256
time_horizon: 128
batch_size: 128
buffer_size: 2048

max_steps: 5.0e5
num_epoch: 3
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
curiosity:
strength: 0.02
gamma: 0.99
encoding_size: 256
use_curiosity: true
curiosity_strength: 0.01
curiosity_enc_size: 256
time_horizon: 128
batch_size: 64
buffer_size: 2024

max_steps: 5.0e5
num_epoch: 3
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
curiosity:
strength: 0.01
gamma: 0.99
encoding_size: 256
3DBallLearning:
normalize: true

time_horizon: 1000
lambd: 0.99
gamma: 0.995
beta: 0.001
3DBallHardLearning:

summary_freq: 1000
time_horizon: 1000
max_steps: 5.0e5
gamma: 0.995
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.995
TennisLearning:
normalize: true

time_horizon: 1000
batch_size: 2024
buffer_size: 20240
gamma: 0.995
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.995
CrawlerDynamicLearning:
normalize: true

buffer_size: 20240
gamma: 0.995
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.995
WalkerLearning:
normalize: true

buffer_size: 20480
gamma: 0.995
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.995
ReacherLearning:
normalize: true

hidden_units: 128
memory_size: 256
beta: 1.0e-2
gamma: 0.99
num_epoch: 3
buffer_size: 1024
batch_size: 128

docs/Training-PPO.md (45 lines changed)


ML-Agents PPO algorithm is implemented in TensorFlow and runs in a separate
Python process (communicating with the running Unity application over a socket).
To train an agent, you will need to provide the agent one or more reward signals which
the agent should attempt to maximize. See [Reward Signals](Training-RewardSignals.md)
for the available reward signals and the corresponding hyperparameters.
See [Training ML-Agents](Training-ML-Agents.md) for instructions on running the
training program, `learn.py`.

## Hyperparameters
### Gamma

`gamma` corresponds to the discount factor for future rewards. This can be
thought of as how far into the future the agent should care about possible
rewards. In situations when the agent should be acting in the present in order
to prepare for rewards in the distant future, this value should be large. In
cases when rewards are more immediate, it can be smaller.

Typical Range: `0.8` - `0.995`

### Reward Signals

In reinforcement learning, the goal is to learn a Policy that maximizes reward.
At a base level, the reward is given by the environment. However, we could imagine
rewarding the agent for various different behaviors. For instance, we could reward
the agent for exploring new states, rather than just when an explicit reward is given.
Furthermore, we could mix reward signals to help the learning process.

`reward_signals` provides a section to define [reward signals.](Training-RewardSignals.md)
ML-Agents provides two reward signals by default, the Extrinsic (environment) reward, and the
Curiosity reward, which can be used to encourage exploration in sparse extrinsic reward
environments.
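As a small, hypothetical illustration of how these settings combine (the numbers below are not from any shipped config): each signal's raw reward is multiplied by its `strength` before it is used for training, so the relative strengths control how much each signal contributes to the total reward the agent optimizes.

```python
# Hypothetical reward_signals settings and one step of made-up raw rewards.
reward_signals = {
    "extrinsic": {"strength": 1.0, "gamma": 0.99},
    "curiosity": {"strength": 0.02, "gamma": 0.99, "encoding_size": 256},
}
raw_rewards = {"extrinsic": 1.0, "curiosity": 0.5}

# Strength-scaled rewards, i.e. what each signal contributes to training.
scaled = {name: reward_signals[name]["strength"] * r for name, r in raw_rewards.items()}
print(scaled)  # {'extrinsic': 1.0, 'curiosity': 0.01}
```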
### Lambda

the agent will need to remember in order to successfully complete the task.
Typical Range: `64` - `512`
## (Optional) Intrinsic Curiosity Module Hyperparameters
The below hyperparameters are only used when `use_curiosity` is set to true.
### Curiosity Encoding Size
`curiosity_enc_size` corresponds to the size of the hidden layer used to encode
the observations within the intrinsic curiosity module. This value should be
small enough to encourage the curiosity module to compress the original
observation, but not so small that the module cannot learn the dynamics of
the environment.
Typical Range: `64` - `256`
### Curiosity Strength
`curiosity_strength` corresponds to the magnitude of the intrinsic reward
generated by the intrinsic curiosity module. This should be scaled so that
the intrinsic reward is not overwhelmed by extrinsic reward signals in the
environment. Likewise, it should not be so large that it overwhelms the
extrinsic reward signal.
Typical Range: `0.001` - `0.1`
## Training Statistics

ml-agents/mlagents/trainers/bc/policy.py (2 lines changed)


self.model.sequence_length: 1,
}
feed_dict = self._fill_eval_dict(feed_dict, brain_info)
feed_dict = self.fill_eval_dict(feed_dict, brain_info)
if self.use_recurrent:
if brain_info.memories.shape[1] == 0:
brain_info.memories = self.make_empty_memory(len(brain_info.agents))

ml-agents/mlagents/trainers/demo_loader.py (65 lines changed)


current_brain_info.vector_observations[0]
)
demo_buffer[0]["actions"].append(next_brain_info.previous_vector_actions[0])
demo_buffer[0]["prev_action"].append(
current_brain_info.previous_vector_actions[0]
)
if next_brain_info.local_done[0]:
demo_buffer.append_update_buffer(
0, batch_size=None, training_length=sequence_length

# First 32 bytes of file dedicated to meta-data.
INITIAL_POS = 33
if not os.path.isfile(file_path):
file_paths = []
if os.path.isdir(file_path):
all_files = os.listdir(file_path)
for _file in all_files:
if _file.endswith(".demo"):
file_paths.append(_file)
elif os.path.isfile(file_path):
file_paths.append(file_path)
else:
"The demonstration file {} does not exist.".format(file_path)
"The demonstration file or directory {} does not exist.".format(file_path)
)
file_extension = pathlib.Path(file_path).suffix
if file_extension != ".demo":

brain_params = None
brain_infos = []
data = open(file_path, "rb").read()
next_pos, pos, obs_decoded = 0, 0, 0
total_expected = 0
while pos < len(data):
next_pos, pos = _DecodeVarint32(data, pos)
if obs_decoded == 0:
meta_data_proto = DemonstrationMetaProto()
meta_data_proto.ParseFromString(data[pos : pos + next_pos])
total_expected = meta_data_proto.number_steps
pos = INITIAL_POS
if obs_decoded == 1:
brain_param_proto = BrainParametersProto()
brain_param_proto.ParseFromString(data[pos : pos + next_pos])
brain_params = BrainParameters.from_proto(brain_param_proto)
pos += next_pos
if obs_decoded > 1:
agent_info = AgentInfoProto()
agent_info.ParseFromString(data[pos : pos + next_pos])
brain_info = BrainInfo.from_agent_proto([agent_info], brain_params)
brain_infos.append(brain_info)
if len(brain_infos) == total_expected:
break
pos += next_pos
obs_decoded += 1
for _file_path in file_paths:
data = open(_file_path, "rb").read()
next_pos, pos, obs_decoded = 0, 0, 0
total_expected = 0
while pos < len(data):
next_pos, pos = _DecodeVarint32(data, pos)
if obs_decoded == 0:
meta_data_proto = DemonstrationMetaProto()
meta_data_proto.ParseFromString(data[pos : pos + next_pos])
total_expected = meta_data_proto.number_steps
pos = INITIAL_POS
if obs_decoded == 1:
brain_param_proto = BrainParametersProto()
brain_param_proto.ParseFromString(data[pos : pos + next_pos])
brain_params = BrainParameters.from_proto(brain_param_proto)
pos += next_pos
if obs_decoded > 1:
agent_info = AgentInfoProto()
agent_info.ParseFromString(data[pos : pos + next_pos])
brain_info = BrainInfo.from_agent_proto([agent_info], brain_params)
brain_infos.append(brain_info)
if len(brain_infos) == total_expected:
break
pos += next_pos
obs_decoded += 1
return brain_params, brain_infos, total_expected

ml-agents/mlagents/trainers/models.py (122 lines changed)


class LearningModel(object):
_version_number_ = 2
def __init__(self, m_size, normalize, use_recurrent, brain, seed):
def __init__(
self, m_size, normalize, use_recurrent, brain, seed, stream_names=None
):
tf.set_random_seed(seed)
self.brain = brain
self.vector_in = None

)
self.mask_input = tf.placeholder(shape=[None], dtype=tf.float32, name="masks")
self.mask = tf.cast(self.mask_input, tf.int32)
self.stream_names = stream_names or []
self.use_recurrent = use_recurrent
if self.use_recurrent:
self.m_size = m_size

return global_step, increment_step
@staticmethod
def scaled_init(scale):
return c_layers.variance_scaling_initializer(scale)
@staticmethod
def swish(input_activation):
"""Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
return tf.multiply(input_activation, tf.nn.sigmoid(input_activation))

shape=[None, self.vec_obs_size], dtype=tf.float32, name=name
)
if self.normalize:
self.running_mean = tf.get_variable(
"running_mean",
[self.vec_obs_size],
trainable=False,
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
self.running_variance = tf.get_variable(
"running_variance",
[self.vec_obs_size],
trainable=False,
dtype=tf.float32,
initializer=tf.ones_initializer(),
)
self.update_mean, self.update_variance = self.create_normalizer_update(
self.vector_in
)
self.normalized_state = tf.clip_by_value(
(self.vector_in - self.running_mean)
/ tf.sqrt(
self.running_variance / (tf.cast(self.global_step, tf.float32) + 1)
),
-5,
5,
name="normalized_state",
)
return self.normalized_state
self.create_normalizer(self.vector_in)
return self.normalize_vector_obs(self.vector_in)
def normalize_vector_obs(self, vector_obs):
normalized_state = tf.clip_by_value(
(vector_obs - self.running_mean)
/ tf.sqrt(
self.running_variance
/ (tf.cast(self.normalization_steps, tf.float32) + 1)
),
-5,
5,
name="normalized_state",
)
return normalized_state
def create_normalizer(self, vector_obs):
self.normalization_steps = tf.get_variable(
"normalization_steps",
[],
trainable=False,
dtype=tf.int32,
initializer=tf.ones_initializer(),
)
self.running_mean = tf.get_variable(
"running_mean",
[self.vec_obs_size],
trainable=False,
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
self.running_variance = tf.get_variable(
"running_variance",
[self.vec_obs_size],
trainable=False,
dtype=tf.float32,
initializer=tf.ones_initializer(),
)
self.update_normalization = self.create_normalizer_update(vector_obs)
) / tf.cast(tf.add(self.global_step, 1), tf.float32)
) / tf.cast(tf.add(self.normalization_steps, 1), tf.float32)
return update_mean, update_variance
update_norm_step = tf.assign(
self.normalization_steps, self.normalization_steps + 1
)
return tf.group([update_mean, update_variance, update_norm_step])
@staticmethod
def create_vector_observation_encoder(

m_size = memory_in.get_shape().as_list()[1]
lstm_input_state = tf.reshape(input_state, shape=[-1, sequence_length, s_size])
memory_in = tf.reshape(memory_in[:, :], [-1, m_size])
_half_point = int(m_size / 2)
half_point = int(m_size / 2)
rnn_cell = tf.contrib.rnn.BasicLSTMCell(_half_point)
rnn_cell = tf.contrib.rnn.BasicLSTMCell(half_point)
memory_in[:, :_half_point], memory_in[:, _half_point:]
memory_in[:, :half_point], memory_in[:, half_point:]
recurrent_output = tf.reshape(recurrent_output, shape=[-1, _half_point])
recurrent_output = tf.reshape(recurrent_output, shape=[-1, half_point])
def create_value_heads(self, stream_names, hidden_input):
"""
Creates one value estimator head for each reward signal in stream_names.
Also creates the node corresponding to the mean of all the value heads in self.value.
self.value_heads is a dictionary of stream name to node containing the value estimator head for that signal.
:param stream_names: The list of reward signal names
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top
of the hidden input.
"""
self.value_heads = {}
for name in stream_names:
value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name))
self.value_heads[name] = value
self.value = tf.reduce_mean(list(self.value_heads.values()), 0)
def create_cc_actor_critic(self, h_size, num_layers):
"""
Creates Continuous control actor-critic model.

kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01),
)
log_sigma_sq = tf.get_variable(
self.log_sigma_sq = tf.get_variable(
"log_sigma_squared",
[self.act_size[0]],
dtype=tf.float32,

sigma_sq = tf.exp(log_sigma_sq)
sigma_sq = tf.exp(self.log_sigma_sq)
self.epsilon = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon"

all_probs = (
-0.5 * tf.square(tf.stop_gradient(self.output_pre) - mu) / sigma_sq
- 0.5 * tf.log(2.0 * np.pi)
- 0.5 * log_sigma_sq
- 0.5 * self.log_sigma_sq
self.entropy = 0.5 * tf.reduce_mean(tf.log(2 * np.pi * np.e) + log_sigma_sq)
self.entropy = 0.5 * tf.reduce_mean(
tf.log(2 * np.pi * np.e) + self.log_sigma_sq
)
value = tf.layers.dense(hidden_value, 1, activation=None)
self.value = tf.identity(value, name="value_estimate")
self.create_value_heads(self.stream_names, hidden_value)
self.all_old_log_probs = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="old_probabilities"

self.output = tf.identity(output)
self.normalized_logits = tf.identity(normalized_logits, name="action")
value = tf.layers.dense(hidden, 1, activation=None)
self.value = tf.identity(value, name="value_estimate")
self.create_value_heads(self.stream_names, hidden)
self.action_holder = tf.placeholder(
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder"

ml-agents/mlagents/trainers/policy.py (2 lines changed)


run_out = dict(zip(list(out_dict.keys()), network_out))
return run_out
def _fill_eval_dict(self, feed_dict, brain_info):
def fill_eval_dict(self, feed_dict, brain_info):
for i, _ in enumerate(brain_info.visual_observations):
feed_dict[self.model.visual_in[i]] = brain_info.visual_observations[i]
if self.use_vec_obs:

ml-agents/mlagents/trainers/ppo/models.py (222 lines changed)


use_recurrent=False,
num_layers=2,
m_size=None,
use_curiosity=False,
curiosity_strength=0.01,
curiosity_enc_size=128,
stream_names=None,
):
"""
Takes a Unity environment and model-specific hyper-parameters and returns the

:param h_size: Size of hidden layers
:param epsilon: Value for policy-divergence threshold.
:param beta: Strength of entropy regularization.
:return: a sub-class of PPOAgent tailored to the environment.
:param seed: Seed to use for initialization of model.
:param stream_names: List of names of value streams. Usually, a list of the Reward Signals being used.
:return: a sub-class of PPOAgent tailored to the environment.
LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed)
self.use_curiosity = use_curiosity
LearningModel.__init__(
self, m_size, normalize, use_recurrent, brain, seed, stream_names
)
if num_layers < 1:
num_layers = 1
self.last_reward, self.new_reward, self.update_reward = (

self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy
else:
self.create_dc_actor_critic(h_size, num_layers)
if self.use_curiosity:
self.curiosity_enc_size = curiosity_enc_size
self.curiosity_strength = curiosity_strength
encoded_state, encoded_next_state = self.create_curiosity_encoders()
self.create_inverse_model(encoded_state, encoded_next_state)
self.create_forward_model(encoded_state, encoded_next_state)
self.create_ppo_optimizer(
self.create_losses(
self.value,
self.value_heads,
self.entropy,
beta,
epsilon,

update_reward = tf.assign(last_reward, new_reward)
return last_reward, new_reward, update_reward
def create_curiosity_encoders(self):
"""
Creates state encoders for current and future observations.
Used for implementation of Curiosity-driven Exploration by Self-supervised Prediction
See https://arxiv.org/abs/1705.05363 for more details.
:return: current and future state encoder tensors.
"""
encoded_state_list = []
encoded_next_state_list = []
if self.vis_obs_size > 0:
self.next_visual_in = []
visual_encoders = []
next_visual_encoders = []
for i in range(self.vis_obs_size):
# Create input ops for next (t+1) visual observations.
next_visual_input = self.create_visual_input(
self.brain.camera_resolutions[i],
name="next_visual_observation_" + str(i),
)
self.next_visual_in.append(next_visual_input)
# Create the encoder ops for current and next visual input. Note that these encoders are siamese.
encoded_visual = self.create_visual_observation_encoder(
self.visual_in[i],
self.curiosity_enc_size,
self.swish,
1,
"stream_{}_visual_obs_encoder".format(i),
False,
)
encoded_next_visual = self.create_visual_observation_encoder(
self.next_visual_in[i],
self.curiosity_enc_size,
self.swish,
1,
"stream_{}_visual_obs_encoder".format(i),
True,
)
visual_encoders.append(encoded_visual)
next_visual_encoders.append(encoded_next_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
hidden_next_visual = tf.concat(next_visual_encoders, axis=1)
encoded_state_list.append(hidden_visual)
encoded_next_state_list.append(hidden_next_visual)
if self.vec_obs_size > 0:
# Create the encoder ops for current and next vector input. Note that these encoders are siamese.
# Create input op for next (t+1) vector observation.
self.next_vector_in = tf.placeholder(
shape=[None, self.vec_obs_size],
dtype=tf.float32,
name="next_vector_observation",
)
encoded_vector_obs = self.create_vector_observation_encoder(
self.vector_in,
self.curiosity_enc_size,
self.swish,
2,
"vector_obs_encoder",
False,
)
encoded_next_vector_obs = self.create_vector_observation_encoder(
self.next_vector_in,
self.curiosity_enc_size,
self.swish,
2,
"vector_obs_encoder",
True,
)
encoded_state_list.append(encoded_vector_obs)
encoded_next_state_list.append(encoded_next_vector_obs)
encoded_state = tf.concat(encoded_state_list, axis=1)
encoded_next_state = tf.concat(encoded_next_state_list, axis=1)
return encoded_state, encoded_next_state
def create_inverse_model(self, encoded_state, encoded_next_state):
"""
Creates inverse model TensorFlow ops for Curiosity module.
Predicts action taken given current and future encoded states.
:param encoded_state: Tensor corresponding to encoded current state.
:param encoded_next_state: Tensor corresponding to encoded next state.
"""
combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
hidden = tf.layers.dense(combined_input, 256, activation=self.swish)
if self.brain.vector_action_space_type == "continuous":
pred_action = tf.layers.dense(hidden, self.act_size[0], activation=None)
squared_difference = tf.reduce_sum(
tf.squared_difference(pred_action, self.selected_actions), axis=1
)
self.inverse_loss = tf.reduce_mean(
tf.dynamic_partition(squared_difference, self.mask, 2)[1]
)
else:
pred_action = tf.concat(
[
tf.layers.dense(hidden, self.act_size[i], activation=tf.nn.softmax)
for i in range(len(self.act_size))
],
axis=1,
)
cross_entropy = tf.reduce_sum(
-tf.log(pred_action + 1e-10) * self.selected_actions, axis=1
)
self.inverse_loss = tf.reduce_mean(
tf.dynamic_partition(cross_entropy, self.mask, 2)[1]
)
def create_forward_model(self, encoded_state, encoded_next_state):
"""
Creates forward model TensorFlow ops for Curiosity module.
Predicts encoded future state based on encoded current state and given action.
:param encoded_state: Tensor corresponding to encoded current state.
:param encoded_next_state: Tensor corresponding to encoded next state.
"""
combined_input = tf.concat([encoded_state, self.selected_actions], axis=1)
hidden = tf.layers.dense(combined_input, 256, activation=self.swish)
# We compare against the concatenation of all observation streams, hence `self.vis_obs_size + int(self.vec_obs_size > 0)`.
pred_next_state = tf.layers.dense(
hidden,
self.curiosity_enc_size * (self.vis_obs_size + int(self.vec_obs_size > 0)),
activation=None,
)
squared_difference = 0.5 * tf.reduce_sum(
tf.squared_difference(pred_next_state, encoded_next_state), axis=1
)
self.intrinsic_reward = tf.clip_by_value(
self.curiosity_strength * squared_difference, 0, 1
)
self.forward_loss = tf.reduce_mean(
tf.dynamic_partition(squared_difference, self.mask, 2)[1]
)
def create_ppo_optimizer(
self, probs, old_probs, value, entropy, beta, epsilon, lr, max_step
def create_losses(
self, probs, old_probs, value_heads, entropy, beta, epsilon, lr, max_step
:param value: Current value estimate
:param value_heads: Value estimate tensors from each value stream
:param beta: Entropy regularization strength
:param entropy: Current policy entropy
:param epsilon: Value for policy-divergence threshold

self.returns_holder = tf.placeholder(
shape=[None], dtype=tf.float32, name="discounted_rewards"
)
self.returns_holders = {}
self.old_values = {}
for name in value_heads.keys():
returns_holder = tf.placeholder(
shape=[None], dtype=tf.float32, name="{}_returns".format(name)
)
old_value = tf.placeholder(
shape=[None], dtype=tf.float32, name="{}_value_estimate".format(name)
)
self.returns_holders[name] = returns_holder
self.old_values[name] = old_value
self.advantage = tf.placeholder(
shape=[None, 1], dtype=tf.float32, name="advantages"
)

self.old_value = tf.placeholder(
shape=[None], dtype=tf.float32, name="old_value_estimates"
)
decay_epsilon = tf.train.polynomial_decay(
epsilon, self.global_step, max_step, 0.1, power=1.0
)

optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
clipped_value_estimate = self.old_value + tf.clip_by_value(
tf.reduce_sum(value, axis=1) - self.old_value, -decay_epsilon, decay_epsilon
)
value_losses = []
for name, head in value_heads.items():
clipped_value_estimate = self.old_values[name] + tf.clip_by_value(
tf.reduce_sum(head, axis=1) - self.old_values[name],
-decay_epsilon,
decay_epsilon,
)
v_opt_a = tf.squared_difference(
self.returns_holders[name], tf.reduce_sum(head, axis=1)
)
v_opt_b = tf.squared_difference(
self.returns_holders[name], clipped_value_estimate
)
value_loss = tf.reduce_mean(
tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.mask, 2)[1]
)
value_losses.append(value_loss)
self.value_loss = tf.reduce_mean(value_losses)
v_opt_a = tf.squared_difference(
self.returns_holder, tf.reduce_sum(value, axis=1)
)
v_opt_b = tf.squared_difference(self.returns_holder, clipped_value_estimate)
self.value_loss = tf.reduce_mean(
tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.mask, 2)[1]
)
# Here we calculate PPO policy loss. In continuous control this is done independently for each action gaussian
# and then averaged together. This provides significantly better performance than treating the probability
# as an average of probabilities, or as a joint probability.
r_theta = tf.exp(probs - old_probs)
p_opt_a = r_theta * self.advantage
p_opt_b = (

* tf.reduce_mean(tf.dynamic_partition(entropy, self.mask, 2)[1])
)
if self.use_curiosity:
self.loss += 10 * (0.2 * self.forward_loss + 0.8 * self.inverse_loss)
def create_ppo_optimizer(self):
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
self.update_batch = optimizer.minimize(self.loss)

ml-agents/mlagents/trainers/ppo/policy.py (138 lines changed)


import logging
import numpy as np
from mlagents.trainers import BrainInfo, ActionInfo
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
create_reward_signal,
)
logger = logging.getLogger("mlagents.trainers")

:param load: Whether a pre-trained model will be loaded or a new one created.
"""
super().__init__(seed, brain, trainer_params)
self.has_updated = False
self.use_curiosity = bool(trainer_params["use_curiosity"])
reward_signal_configs = trainer_params["reward_signals"]
self.reward_signals = {}
with self.graph.as_default():
self.model = PPOModel(
brain,

use_recurrent=trainer_params["use_recurrent"],
num_layers=int(trainer_params["num_layers"]),
m_size=self.m_size,
use_curiosity=bool(trainer_params["use_curiosity"]),
curiosity_strength=float(trainer_params["curiosity_strength"]),
curiosity_enc_size=float(trainer_params["curiosity_enc_size"]),
stream_names=list(reward_signal_configs.keys()),
self.model.create_ppo_optimizer()
# Create reward signals
for reward_signal, config in reward_signal_configs.items():
self.reward_signals[reward_signal] = create_reward_signal(
self, reward_signal, config
)
if load:
self._load_graph()

self.inference_dict = {
"action": self.model.output,
"log_probs": self.model.all_log_probs,
"value": self.model.value,
"value": self.model.value_heads,
"entropy": self.model.entropy,
"learning_rate": self.model.learning_rate,
}

self.inference_dict["memory_out"] = self.model.memory_out
if is_training and self.use_vec_obs and trainer_params["normalize"]:
self.inference_dict["update_mean"] = self.model.update_mean
self.inference_dict["update_variance"] = self.model.update_variance
if (
is_training
and self.use_vec_obs
and trainer_params["normalize"]
and not load
):
self.inference_dict["update_mean"] = self.model.update_normalization
self.total_policy_loss = self.model.policy_loss
"policy_loss": self.model.policy_loss,
"policy_loss": self.total_policy_loss,
if self.use_curiosity:
self.update_dict["forward_loss"] = self.model.forward_loss
self.update_dict["inverse_loss"] = self.model.inverse_loss
def evaluate(self, brain_info):
"""

size=(len(brain_info.vector_observations), self.model.act_size[0])
)
feed_dict[self.model.epsilon] = epsilon
feed_dict = self._fill_eval_dict(feed_dict, brain_info)
feed_dict = self.fill_eval_dict(feed_dict, brain_info)
run_out = self._execute_model(feed_dict, self.inference_dict)
if self.use_continuous_act:
run_out["random_normal_epsilon"] = epsilon

self.model.batch_size: num_sequences,
self.model.sequence_length: self.sequence_length,
self.model.mask_input: mini_batch["masks"].flatten(),
self.model.returns_holder: mini_batch["discounted_returns"].flatten(),
self.model.old_value: mini_batch["value_estimates"].flatten(),
for name in self.reward_signals:
feed_dict[self.model.returns_holders[name]] = mini_batch[
"{}_returns".format(name)
].flatten()
feed_dict[self.model.old_values[name]] = mini_batch[
"{}_value_estimates".format(name)
].flatten()
if self.use_continuous_act:
feed_dict[self.model.output_pre] = mini_batch["actions_pre"].reshape(
[-1, self.model.act_size[0]]

feed_dict[self.model.vector_in] = mini_batch["vector_obs"].reshape(
[-1, self.vec_obs_size]
)
if self.use_curiosity:
feed_dict[self.model.next_vector_in] = mini_batch[
"next_vector_in"
].reshape([-1, self.vec_obs_size])
if self.model.vis_obs_size > 0:
for i, _ in enumerate(self.model.visual_in):
_obs = mini_batch["visual_obs%d" % i]

else:
feed_dict[self.model.visual_in[i]] = _obs
if self.use_curiosity:
for i, _ in enumerate(self.model.visual_in):
_obs = mini_batch["next_visual_obs%d" % i]
if self.sequence_length > 1 and self.use_recurrent:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.next_visual_in[i]] = _obs.reshape(
[-1, _w, _h, _c]
)
else:
feed_dict[self.model.next_visual_in[i]] = _obs
self.has_updated = True
def get_intrinsic_rewards(self, curr_info, next_info):
"""
Generates intrinsic reward used for Curiosity-based training.
:BrainInfo curr_info: Current BrainInfo.
:BrainInfo next_info: Next BrainInfo.
:return: Intrinsic rewards for all agents.
"""
if self.use_curiosity:
if len(curr_info.agents) == 0:
return []
feed_dict = {
self.model.batch_size: len(next_info.vector_observations),
self.model.sequence_length: 1,
}
if self.use_continuous_act:
feed_dict[
self.model.selected_actions
] = next_info.previous_vector_actions
else:
feed_dict[self.model.action_holder] = next_info.previous_vector_actions
for i in range(self.model.vis_obs_size):
feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[
i
]
if self.use_vec_obs:
feed_dict[self.model.vector_in] = curr_info.vector_observations
feed_dict[self.model.next_vector_in] = next_info.vector_observations
if self.use_recurrent:
if curr_info.memories.shape[1] == 0:
curr_info.memories = self.make_empty_memory(len(curr_info.agents))
feed_dict[self.model.memory_in] = curr_info.memories
intrinsic_rewards = self.sess.run(
self.model.intrinsic_reward, feed_dict=feed_dict
) * float(self.has_updated)
return intrinsic_rewards
else:
return None
def get_value_estimate(self, brain_info, idx):
def get_value_estimates(self, brain_info, idx):
:return: Value estimate.
:return: The value estimate dictionary with key being the name of the reward signal and the value the
corresponding value estimate.
"""
feed_dict = {self.model.batch_size: 1, self.model.sequence_length: 1}
for i in range(len(brain_info.visual_observations)):

feed_dict[self.model.prev_action] = brain_info.previous_vector_actions[
idx
].reshape([-1, len(self.model.act_size)])
value_estimate = self.sess.run(self.model.value, feed_dict)
return value_estimate
value_estimates = self.sess.run(self.model.value_heads, feed_dict)
return value_estimates
def get_action(self, brain_info: BrainInfo) -> ActionInfo:
"""
Decides actions given observations information, and takes them in environment.
:param brain_info: A dictionary of brain names and BrainInfo from environment.
:return: an ActionInfo containing action, memories, values and an object
to be passed to add experiences
"""
if len(brain_info.agents) == 0:
return ActionInfo([], [], [], None, None)
run_out = self.evaluate(brain_info)
mean_values = np.mean(
np.array(list(run_out.get("value").values())), axis=0
).flatten()
return ActionInfo(
action=run_out.get("action"),
memory=run_out.get("memory_out"),
text=None,
value=mean_values,
outputs=run_out,
)
def get_last_reward(self):
"""

ml-agents/mlagents/trainers/ppo/trainer.py (231 lines changed)


# # Unity ML-Agents Toolkit
# ## ML-Agent Learning (PPO)
# Contains an implementation of PPO as described (https://arxiv.org/abs/1707.06347).
# Contains an implementation of PPO as described in: https://arxiv.org/abs/1707.06347
from collections import deque
from collections import deque, defaultdict
from typing import Any, List
import numpy as np

from mlagents.trainers.buffer import Buffer
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.trainers.action_info import ActionInfoOutputs
logger = logging.getLogger("mlagents.trainers")

"beta",
"buffer_size",
"epsilon",
"gamma",
"hidden_units",
"lambd",
"learning_rate",

"use_recurrent",
"summary_path",
"memory_size",
"use_curiosity",
"curiosity_strength",
"curiosity_enc_size",
"reward_signals",
self.check_param_keys()
self.check_param_keys()
self.use_curiosity = bool(trainer_parameters["use_curiosity"])
# Make sure we have at least one reward_signal
if not self.trainer_parameters["reward_signals"]:
raise UnityTrainerException(
"No reward signals were defined. At least one must be used with {}.".format(
self.__class__.__name__
)
)
stats = {
"Environment/Cumulative Reward": [],
"Environment/Episode Length": [],
"Policy/Value Estimate": [],
"Policy/Entropy": [],
"Losses/Value Loss": [],
"Losses/Policy Loss": [],
"Policy/Learning Rate": [],
}
if self.use_curiosity:
stats["Losses/Forward Loss"] = []
stats["Losses/Inverse Loss"] = []
stats["Policy/Curiosity Reward"] = []
self.intrinsic_rewards = {}
stats = defaultdict(list)
# collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
# used for reporting only. We always want to report the environment reward to Tensorboard, regardless
# of what reward signals are actually present.
self.collected_rewards = {"environment": {}}
for _reward_signal in self.policy.reward_signals.keys():
self.collected_rewards[_reward_signal] = {}
self.cumulative_rewards = {}
return """Hyperparameters for the PPO Trainer of brain {0}: \n{1}""".format(
return """Hyperparameters for the {0} of brain {1}: \n{2}""".format(
self.__class__.__name__,
"\n".join(
[
"\t{0}:\t{1}".format(x, self.trainer_parameters[x])
for x in self.param_keys
]
),
self.dict_to_str(self.trainer_parameters, 0),
)
@property

"""
Increment the step count of the trainer and update the last reward.
"""
if len(self.stats["Environment/Cumulative Reward"]) > 0:
if self.stats["Environment/Cumulative Reward"]:
mean_reward = np.mean(self.stats["Environment/Cumulative Reward"])
self.policy.update_reward(mean_reward)
self.policy.increment_step()

"""
Constructs a BrainInfo which contains the most recent previous experiences for all agents info
Constructs a BrainInfo which contains the most recent previous experiences for all agents
which correspond to the agents in a provided next_info.
:BrainInfo next_info: A t+1 BrainInfo.
:return: curr_info: Reconstructed BrainInfo to match agents of next_info.

"""
self.trainer_metrics.start_experience_collection_timer()
if take_action_outputs:
self.stats["Policy/Value Estimate"].append(
take_action_outputs["value"].mean()
)
for name, signal in self.policy.reward_signals.items():
self.stats[signal.value_name].append(
np.mean(take_action_outputs["value"][name])
)
curr_info = curr_all_info[self.brain_name]
next_info = next_all_info[self.brain_name]

else:
curr_to_use = curr_info
intrinsic_rewards = self.policy.get_intrinsic_rewards(curr_to_use, next_info)
tmp_rewards_dict = {}
for name, signal in self.policy.reward_signals.items():
tmp_rewards_dict[name] = signal.evaluate(curr_to_use, next_info)
for agent_id in next_info.agents:
stored_info = self.training_buffer[agent_id].last_brain_info

stored_info.action_masks[idx], padding_value=1
)
a_dist = stored_take_action_outputs["log_probs"]
# value is a dictionary from name of reward to value estimate of the value head
value = stored_take_action_outputs["value"]
self.training_buffer[agent_id]["actions"].append(actions[idx])
self.training_buffer[agent_id]["prev_action"].append(

if self.use_curiosity:
self.training_buffer[agent_id]["rewards"].append(
next_info.rewards[next_idx] + intrinsic_rewards[next_idx]
)
else:
self.training_buffer[agent_id]["rewards"].append(
next_info.rewards[next_idx]
)
self.training_buffer[agent_id]["action_probs"].append(a_dist[idx])
self.training_buffer[agent_id]["value_estimates"].append(
value[idx][0]
self.training_buffer[agent_id]["done"].append(
next_info.local_done[next_idx]
if agent_id not in self.cumulative_rewards:
self.cumulative_rewards[agent_id] = 0
self.cumulative_rewards[agent_id] += next_info.rewards[next_idx]
if self.use_curiosity:
if agent_id not in self.intrinsic_rewards:
self.intrinsic_rewards[agent_id] = 0
self.intrinsic_rewards[agent_id] += intrinsic_rewards[next_idx]
for name, reward_result in tmp_rewards_dict.items():
# 0 because we use the scaled reward to train the agent
self.training_buffer[agent_id][
"{}_rewards".format(name)
].append(reward_result.scaled_reward[next_idx])
self.training_buffer[agent_id][
"{}_value_estimates".format(name)
].append(value[name][idx][0])
self.training_buffer[agent_id]["action_probs"].append(a_dist[idx])
for name, rewards in self.collected_rewards.items():
if agent_id not in rewards:
rewards[agent_id] = 0
if name == "environment":
# Report the reward from the environment
rewards[agent_id] += np.array(next_info.rewards)[next_idx]
else:
# Report the reward signals
rewards[agent_id] += tmp_rewards_dict[name].scaled_reward[
next_idx
]
if not next_info.local_done[next_idx]:
if agent_id not in self.episode_steps:
self.episode_steps[agent_id] = 0

:param current_info: Dictionary of all current brains and corresponding BrainInfo.
:param new_info: Dictionary of all next brains and corresponding BrainInfo.
"""
self.trainer_metrics.start_experience_collection_timer()
info = new_info[self.brain_name]
for l in range(len(info.agents)):
agent_actions = self.training_buffer[info.agents[l]]["actions"]

) and len(agent_actions) > 0:
agent_id = info.agents[l]
if info.max_reached[l]:
bootstrapping_info = self.training_buffer[agent_id].last_brain_info
idx = bootstrapping_info.agents.index(agent_id)
else:
bootstrapping_info = info
idx = l
value_next = self.policy.get_value_estimates(bootstrapping_info, idx)
value_next = 0.0
else:
if info.max_reached[l]:
bootstrapping_info = self.training_buffer[
agent_id
].last_brain_info
idx = bootstrapping_info.agents.index(agent_id)
else:
bootstrapping_info = info
idx = l
value_next = self.policy.get_value_estimate(bootstrapping_info, idx)
value_next["extrinsic"] = 0.0
tmp_advantages = []
tmp_returns = []
for name in self.policy.reward_signals:
bootstrap_value = value_next[name]
self.training_buffer[agent_id]["advantages"].set(
get_gae(
rewards=self.training_buffer[agent_id]["rewards"].get_batch(),
value_estimates=self.training_buffer[agent_id][
"value_estimates"
].get_batch(),
value_next=value_next,
gamma=self.trainer_parameters["gamma"],
local_rewards = self.training_buffer[agent_id][
"{}_rewards".format(name)
].get_batch()
local_value_estimates = self.training_buffer[agent_id][
"{}_value_estimates".format(name)
].get_batch()
local_advantage = get_gae(
rewards=local_rewards,
value_estimates=local_value_estimates,
value_next=bootstrap_value,
gamma=self.policy.reward_signals[name].gamma,
)
self.training_buffer[agent_id]["discounted_returns"].set(
self.training_buffer[agent_id]["advantages"].get_batch()
+ self.training_buffer[agent_id]["value_estimates"].get_batch()
)
local_return = local_advantage + local_value_estimates
# This is later used as a target for the different value estimates
self.training_buffer[agent_id]["{}_returns".format(name)].set(
local_return
)
self.training_buffer[agent_id]["{}_advantage".format(name)].set(
local_advantage
)
tmp_advantages.append(local_advantage)
tmp_returns.append(local_return)
global_advantages = list(np.mean(np.array(tmp_advantages), axis=0))
global_returns = list(np.mean(np.array(tmp_returns), axis=0))
self.training_buffer[agent_id]["advantages"].set(global_advantages)
self.training_buffer[agent_id]["discounted_returns"].set(global_returns)
self.training_buffer.append_update_buffer(
agent_id,

self.training_buffer[agent_id].reset_agent()
if info.local_done[l]:
self.cumulative_returns_since_policy_update.append(
self.cumulative_rewards.get(agent_id, 0)
)
self.stats["Environment/Cumulative Reward"].append(
self.cumulative_rewards.get(agent_id, 0)
)
self.reward_buffer.appendleft(
self.cumulative_rewards.get(agent_id, 0)
)
self.cumulative_rewards[agent_id] = 0
if self.use_curiosity:
self.stats["Policy/Curiosity Reward"].append(
self.intrinsic_rewards.get(agent_id, 0)
)
self.intrinsic_rewards[agent_id] = 0
self.trainer_metrics.end_experience_collection_timer()
for name, rewards in self.collected_rewards.items():
if name == "environment":
self.cumulative_returns_since_policy_update.append(
rewards.get(agent_id, 0)
)
self.stats["Environment/Cumulative Reward"].append(
rewards.get(agent_id, 0)
)
rewards[agent_id] = 0
self.reward_buffer.appendleft(rewards.get(agent_id, 0))
else:
self.stats[
self.policy.reward_signals[name].stat_name
].append(rewards.get(agent_id, 0))
rewards[agent_id] = 0
def end_episode(self):
"""

self.training_buffer.reset_local_buffers()
for agent_id in self.cumulative_rewards:
self.cumulative_rewards[agent_id] = 0
if self.use_curiosity:
for agent_id in self.intrinsic_rewards:
self.intrinsic_rewards[agent_id] = 0
for rewards in self.collected_rewards.values():
for agent_id in rewards:
rewards[agent_id] = 0
def is_ready_update(self):
"""

def update_policy(self):
"""
Uses demonstration_buffer to update the policy.
The reward signal generators must be updated in this method at their own pace.
"""
self.trainer_metrics.start_policy_update_timer(
number_experiences=len(self.training_buffer.update_buffer["actions"]),

)
value_total.append(run_out["value_loss"])
policy_total.append(np.abs(run_out["policy_loss"]))
if self.use_curiosity:
inverse_total.append(run_out["inverse_loss"])
forward_total.append(run_out["forward_loss"])
if self.use_curiosity:
self.stats["Losses/Forward Loss"].append(np.mean(forward_total))
self.stats["Losses/Inverse Loss"].append(np.mean(inverse_total))
for _, reward_signal in self.policy.reward_signals.items():
update_stats = reward_signal.update(
self.training_buffer.update_buffer, n_sequences
)
for stat, val in update_stats.items():
self.stats[stat].append(val)
self.training_buffer.reset_update_buffer()
self.trainer_metrics.end_policy_update()
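To summarize the trainer change above: returns and advantages are now computed independently for each reward signal, each using that signal's own `gamma`, and the per-signal advantages are averaged into the single advantage fed to the PPO loss, while the per-signal returns become the targets for the matching value heads. Below is a rough, standalone sketch of that flow with simplified helpers and made-up numbers; it is not the trainer's actual `get_gae` implementation.

```python
import numpy as np


def discounted_sum(deltas, discount):
    """Backward discounted cumulative sum."""
    out = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + discount * running
        out[t] = running
    return out


def gae(rewards, value_estimates, value_next, gamma, lambd=0.95):
    """Generalized Advantage Estimation over one agent's trajectory."""
    values = np.append(value_estimates, value_next)
    deltas = rewards + gamma * values[1:] - values[:-1]
    return discounted_sum(deltas, gamma * lambd)


# Made-up per-signal rewards and placeholder value estimates for a 3-step episode.
rewards = {
    "extrinsic": np.array([0.0, 0.0, 1.0]),
    "curiosity": np.array([0.10, 0.05, 0.02]),
}
gammas = {"extrinsic": 0.99, "curiosity": 0.99}
value_estimates = {name: np.zeros(3) for name in rewards}

advantages, returns = {}, {}
for name in rewards:
    adv = gae(rewards[name], value_estimates[name], value_next=0.0, gamma=gammas[name])
    advantages[name] = adv
    returns[name] = adv + value_estimates[name]  # target for that signal's value head

# The PPO policy loss uses the mean advantage across all reward signals.
global_advantage = np.mean(list(advantages.values()), axis=0)
```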

ml-agents/mlagents/trainers/tests/test_ppo.py (158 lines changed)


beta: 5.0e-3
buffer_size: 512
epsilon: 0.2
gamma: 0.99
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4

summary_freq: 1000
use_recurrent: false
memory_size: 8
use_curiosity: false
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)

model.memory_in: np.zeros((1, memory_size)),
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.epsilon: np.array([[0, 1]]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_ppo_model_dc_vector_curio(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=True, visual_inputs=0
)
env = UnityEnvironment(" ")
model = PPOModel(env.brains["RealFakeBrain"], use_curiosity=True)
init = tf.global_variables_initializer()
sess.run(init)
run_list = [
model.output,
model.all_log_probs,
model.value,
model.entropy,
model.learning_rate,
model.intrinsic_reward,
]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.next_vector_in: np.array(
[[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]
),
model.action_holder: [[0], [0]],
model.action_masks: np.ones([2, 2]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_ppo_model_cc_vector_curio(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
model = PPOModel(env.brains["RealFakeBrain"], use_curiosity=True)
init = tf.global_variables_initializer()
sess.run(init)
run_list = [
model.output,
model.all_log_probs,
model.value,
model.entropy,
model.learning_rate,
model.intrinsic_reward,
]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.next_vector_in: np.array(
[[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]
),
model.output: [[0.0, 0.0], [0.0, 0.0]],
model.epsilon: np.array([[0, 1], [2, 3]]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_ppo_model_dc_visual_curio(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=True, visual_inputs=2
)
env = UnityEnvironment(" ")
model = PPOModel(env.brains["RealFakeBrain"], use_curiosity=True)
init = tf.global_variables_initializer()
sess.run(init)
run_list = [
model.output,
model.all_log_probs,
model.value,
model.entropy,
model.learning_rate,
model.intrinsic_reward,
]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.next_vector_in: np.array(
[[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]
),
model.action_holder: [[0], [0]],
model.visual_in[0]: np.ones([2, 40, 30, 3]),
model.visual_in[1]: np.ones([2, 40, 30, 3]),
model.next_visual_in[0]: np.ones([2, 40, 30, 3]),
model.next_visual_in[1]: np.ones([2, 40, 30, 3]),
model.action_masks: np.ones([2, 2]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_ppo_model_cc_visual_curio(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=2
)
env = UnityEnvironment(" ")
model = PPOModel(env.brains["RealFakeBrain"], use_curiosity=True)
init = tf.global_variables_initializer()
sess.run(init)
run_list = [
model.output,
model.all_log_probs,
model.value,
model.entropy,
model.learning_rate,
model.intrinsic_reward,
]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.next_vector_in: np.array(
[[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]
),
model.output: [[0.0, 0.0], [0.0, 0.0]],
model.visual_in[0]: np.ones([2, 40, 30, 3]),
model.visual_in[1]: np.ones([2, 40, 30, 3]),
model.next_visual_in[0]: np.ones([2, 40, 30, 3]),
model.next_visual_in[1]: np.ones([2, 40, 30, 3]),
model.epsilon: np.array([[0, 1], [2, 3]]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()

ml-agents/mlagents/trainers/trainer.py (22 lines changed)


"brain {2}.".format(k, self.__class__, self.brain_name)
)
def dict_to_str(self, param_dict, num_tabs):
"""
Takes a parameter dictionary and converts it to a human-readable string.
Recurses if there are multiple levels of dict. Used to print out hyperparameters.
param: param_dict: A Dictionary of key, value parameters.
return: A string version of this dictionary.
"""
if not isinstance(param_dict, dict):
return param_dict
else:
append_newline = "\n" if num_tabs > 0 else ""
return append_newline + "\n".join(
[
"\t"
+ " " * num_tabs
+ "{0}:\t{1}".format(
x, self.dict_to_str(param_dict[x], num_tabs + 1)
)
for x in param_dict
]
)
@property
def parameters(self):
"""

docs/Training-RewardSignals.md (114 lines changed)


# Reward Signals
In reinforcement learning, the end goal for the Agent is to discover a behavior (a Policy)
that maximizes a reward. Typically, a reward is defined by your environment, and corresponds
to reaching some goal. These are what we refer to as "extrinsic" rewards, as they are defined
outside of the learning algorithm.
Rewards, however, can be defined outside of the environment as well, to encourage the agent to
behave in certain ways, or to aid the learning of the true extrinsic reward. We refer to these
rewards as "intrinsic" reward signals. The total reward that the agent will learn to maximize can
be a mix of extrinsic and intrinsic reward signals.
ML-Agents allows reward signals to be defined in a modular way, and we provide two reward
signals that can be mixed and matched to help shape your agent's behavior. The `extrinsic` Reward
Signal represents the rewards defined in your environment, and is enabled by default.
The `curiosity` reward signal helps your agent explore when extrinsic rewards are sparse.
## Enabling Reward Signals
Reward signals, like other hyperparameters, are defined in the trainer config `.yaml` file. An
example is provided in `config/trainer_config.yaml`. To enable a reward signal, add it to the
`reward_signals:` section under the brain name. For instance, to enable the extrinsic signal
in addition to a small curiosity reward, you would define your `reward_signals` as follows:
```yaml
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
curiosity:
strength: 0.01
gamma: 0.99
encoding_size: 128
```
Each reward signal should define at least two parameters, `strength` and `gamma`, in addition
to any class-specific hyperparameters. Note that to remove a reward signal, you should delete
its entry entirely from `reward_signals`. At least one reward signal should be left defined
at all times.
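For reference, this block is parsed like any other trainer hyperparameter (the tests in this commit load configs with `yaml.safe_load`), and each top-level key under `reward_signals` becomes one reward signal instance, which is why disabling a signal means deleting its entry rather than zeroing it out. A small, self-contained illustration:

```python
import yaml

config_text = """
reward_signals:
  extrinsic:
    strength: 1.0
    gamma: 0.99
  curiosity:
    strength: 0.01
    gamma: 0.99
    encoding_size: 128
"""
config = yaml.safe_load(config_text)
# One reward signal per key; remove a key to drop that signal entirely.
print(list(config["reward_signals"].keys()))  # ['extrinsic', 'curiosity']
```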
## Reward Signal Types
### The Extrinsic Reward Signal
The `extrinsic` reward signal is simply the reward given by the
[environment](Learning-Environment-Design.md). Remove it to force the agent
to ignore the environment reward.
#### Strength
`strength` is the factor by which to multiply the raw
reward. Typical ranges will vary depending on the reward signal.
Typical Range: `1.0`
#### Gamma
`gamma` corresponds to the discount factor for future rewards. This can be
thought of as how far into the future the agent should care about possible
rewards. In situations when the agent should be acting in the present in order
to prepare for rewards in the distant future, this value should be large. In
cases when rewards are more immediate, it can be smaller.
Typical Range: `0.8` - `0.995`
### The Curiosity Reward Signal
The `curiosity` Reward Signal enables the Intrinsic Curiosity Module. This is an implementation
of the approach described in "Curiosity-driven Exploration by Self-supervised Prediction"
by Pathak, et al. It trains two networks:
* an inverse model, which takes the current and next observation of the agent, encodes them, and
uses the encoding to predict the action that was taken between the observations
* a forward model, which takes the encoded current observation and action, and predicts the
next encoded observation.
The loss of the forward model (the difference between the predicted and actual encoded observations) is used as the intrinsic reward, so the more surprised the model is, the larger the reward will be.
For more information, see
* https://arxiv.org/abs/1705.05363
* https://pathak22.github.io/noreward-rl/
* https://blogs.unity3d.com/2018/06/26/solving-sparse-reward-tasks-with-curiosity/
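To make the forward-model surprise concrete, here is a rough NumPy sketch of how that error turns into the intrinsic reward, mirroring the curiosity model code included in this change (the array values are made up, and in the real module the encodings are learned, not hand-written):

```python
import numpy as np


def curiosity_reward(pred_next_encoding, next_encoding, strength=0.01):
    """Half squared error of the forward model, scaled and clipped to [0, 1]."""
    squared_difference = 0.5 * np.sum(
        (pred_next_encoding - next_encoding) ** 2, axis=1
    )
    return np.clip(strength * squared_difference, 0.0, 1.0)


# Two agents with an encoding size of 4: the first is predicted perfectly
# (no surprise, no reward), the second is badly mispredicted (max reward).
predicted = np.array([[0.1, 0.2, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
actual = np.array([[0.1, 0.2, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
print(curiosity_reward(predicted, actual, strength=0.5))  # [0. 1.]
```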
#### Strength
In this case, `strength` corresponds to the magnitude of the curiosity reward generated
by the intrinsic curiosity module. This should be scaled in order to ensure it is large enough
not to be overwhelmed by extrinsic reward signals in the environment.
Likewise, it should not be so large that it overwhelms the extrinsic reward signal.
Typical Range: `0.001` - `0.1`
#### Gamma
`gamma` corresponds to the discount factor for future rewards.
Typical Range: `0.8` - `0.995`
#### Encoding Size
`encoding_size` corresponds to the size of the encoding used by the intrinsic curiosity model.
This value should be small enough to encourage the ICM to compress the original
observation, but not so small that the ICM cannot learn the dynamics of
the environment.
Default Value: 64
Typical Range: `64` - `256`
#### Learning Rate
`learning_rate` is the learning rate used to update the intrinsic curiosity module.
This should typically be decreased if training is unstable, and the curiosity loss is unstable.
Default Value: `3e-4`
Typical Range: `1e-5` - `1e-3`

ml-agents/mlagents/trainers/tests/mock_brain.py (92 lines changed)


import unittest.mock as mock
import pytest
import numpy as np
def create_mock_brainparams(
number_visual_observations=0,
num_stacked_vector_observations=1,
vector_action_space_type="continuous",
vector_observation_space_size=3,
vector_action_space_size=None,
):
"""
Creates a mock BrainParameters object with parameters.
"""
# Avoid using mutable object as default param
if vector_action_space_size is None:
vector_action_space_size = [2]
mock_brain = mock.Mock()
mock_brain.return_value.number_visual_observations = number_visual_observations
mock_brain.return_value.num_stacked_vector_observations = (
num_stacked_vector_observations
)
mock_brain.return_value.vector_action_space_type = vector_action_space_type
mock_brain.return_value.vector_observation_space_size = (
vector_observation_space_size
)
camrez = {"blackAndWhite": False, "height": 84, "width": 84}
mock_brain.return_value.camera_resolutions = [camrez] * number_visual_observations
mock_brain.return_value.vector_action_space_size = vector_action_space_size
return mock_brain()
def create_mock_braininfo(
num_agents=1,
num_vector_observations=0,
num_vis_observations=0,
num_vector_acts=2,
discrete=False,
):
"""
Creates a mock BrainInfo with observations. Imitates constant
vector/visual observations, rewards, dones, and agents.
:int num_agents: Number of "agents" to imitate in your BrainInfo values.
:int num_vector_observations: Number of "observations" in your observation space
:int num_vis_observations: Number of "observations" in your observation space
:int num_vector_acts: Number of actions in your action space
:bool discrete: Whether or not action space is discrete
"""
mock_braininfo = mock.Mock()
mock_braininfo.return_value.visual_observations = num_vis_observations * [
np.ones((num_agents, 84, 84, 3))
]
mock_braininfo.return_value.vector_observations = np.array(
num_agents * [num_vector_observations * [1]]
)
if discrete:
mock_braininfo.return_value.previous_vector_actions = np.array(
num_agents * [1 * [0.5]]
)
mock_braininfo.return_value.action_masks = np.array(
num_agents * [num_vector_acts * [1.0]]
)
else:
mock_braininfo.return_value.previous_vector_actions = np.array(
num_agents * [num_vector_acts * [0.5]]
)
mock_braininfo.return_value.memories = np.ones((num_agents, 8))
mock_braininfo.return_value.rewards = num_agents * [1.0]
mock_braininfo.return_value.local_done = num_agents * [False]
mock_braininfo.return_value.text_observations = num_agents * [""]
mock_braininfo.return_value.agents = range(0, num_agents)
return mock_braininfo()
def setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo):
"""
Takes a mock UnityEnvironment and adds the appropriate properties, defined by the mock
BrainParameters and BrainInfo.
:Mock mock_env: A mock UnityEnvironment, usually empty.
:Mock mock_brain: A mock Brain object that specifies the params of this environment.
:Mock mock_braininfo: A mock BrainInfo object that will be returned at each step and reset.
"""
mock_env.return_value.academy_name = "MockAcademy"
mock_env.return_value.brains = {"MockBrain": mock_brain}
mock_env.return_value.external_brain_names = ["MockBrain"]
mock_env.return_value.brain_names = ["MockBrain"]
mock_env.return_value.reset.return_value = {"MockBrain": mock_braininfo}
mock_env.return_value.step.return_value = {"MockBrain": mock_braininfo}

ml-agents/mlagents/trainers/tests/test_reward_signals.py (156 lines changed)


import unittest.mock as mock
import pytest
import mlagents.trainers.tests.mock_brain as mb
import numpy as np
import tensorflow as tf
import yaml
import os
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.ppo.trainer import discount_rewards
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.envs import UnityEnvironment
from mlagents.envs.mock_communicator import MockCommunicator
@pytest.fixture
def dummy_config():
return yaml.safe_load(
"""
trainer: ppo
batch_size: 32
beta: 5.0e-3
buffer_size: 512
epsilon: 0.2
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4
max_steps: 5.0e4
normalize: true
num_epoch: 5
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 1000
use_recurrent: false
memory_size: 8
curiosity_strength: 0.0
curiosity_enc_size: 1
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)
@pytest.fixture
def curiosity_dummy_config():
return {"curiosity": {"strength": 0.1, "gamma": 0.9, "encoding_size": 128}}
def create_ppo_policy_mock(
mock_env, dummy_config, reward_signal_config, use_rnn, use_discrete, use_visual
):
if not use_visual:
mock_brain = mb.create_mock_brainparams(
vector_action_space_type="discrete" if use_discrete else "continuous",
vector_action_space_size=[2],
vector_observation_space_size=8,
)
mock_braininfo = mb.create_mock_braininfo(
num_agents=12,
num_vector_observations=8,
num_vector_acts=2,
discrete=use_discrete,
)
else:
mock_brain = mb.create_mock_brainparams(
vector_action_space_type="discrete" if use_discrete else "continuous",
vector_action_space_size=[2],
vector_observation_space_size=0,
number_visual_observations=1,
)
mock_braininfo = mb.create_mock_braininfo(
num_agents=12,
num_vis_observations=1,
num_vector_acts=2,
discrete=use_discrete,
)
mb.setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
env = mock_env()
trainer_parameters = dummy_config
model_path = env.brain_names[0]
trainer_parameters["model_path"] = model_path
trainer_parameters["keep_checkpoints"] = 3
trainer_parameters["reward_signals"].update(reward_signal_config)
trainer_parameters["use_recurrent"] = use_rnn
policy = PPOPolicy(0, mock_brain, trainer_parameters, False, False)
return env, policy
@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_cc_evaluate(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, False, False
)
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
next_brain_info = env.step()[env.brain_names[0]]
scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
brain_info, next_brain_info
)
assert scaled_reward.shape == (12,)
assert unscaled_reward.shape == (12,)
@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_dc_evaluate(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, True, False
)
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
next_brain_info = env.step()[env.brain_names[0]]
scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
brain_info, next_brain_info
)
assert scaled_reward.shape == (12,)
assert unscaled_reward.shape == (12,)
@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_visual_evaluate(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, False, True
)
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
next_brain_info = env.step()[env.brain_names[0]]
scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
brain_info, next_brain_info
)
assert scaled_reward.shape == (12,)
assert unscaled_reward.shape == (12,)
@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_rnn_evaluate(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, True, False, False
)
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
next_brain_info = env.step()[env.brain_names[0]]
scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
brain_info, next_brain_info
)
assert scaled_reward.shape == (12,)
assert unscaled_reward.shape == (12,)
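# The extrinsic signal can be exercised with the same mocks; a minimal sketch follows
# (illustrative, not part of the original diff; the test name is hypothetical). With the
# dummy_config strength of 1.0 and mocked rewards of 1.0 per agent, the scaled reward
# should equal the unscaled reward.
@mock.patch("mlagents.envs.UnityEnvironment")
def test_extrinsic_evaluate_sketch(mock_env, dummy_config, curiosity_dummy_config):
    env, policy = create_ppo_policy_mock(
        mock_env, dummy_config, curiosity_dummy_config, False, False, False
    )
    brain_info = env.reset()[env.brain_names[0]]
    next_brain_info = env.step()[env.brain_names[0]]
    scaled_reward, unscaled_reward = policy.reward_signals["extrinsic"].evaluate(
        brain_info, next_brain_info
    )
    assert scaled_reward.shape == (12,)
    assert np.allclose(scaled_reward, unscaled_reward)  # strength is 1.0 in dummy_config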
if __name__ == "__main__":
pytest.main()

0
ml-agents/mlagents/trainers/components/__init__.py

1
ml-agents/mlagents/trainers/components/reward_signals/__init__.py


from .reward_signal import *

1
ml-agents/mlagents/trainers/components/reward_signals/curiosity/__init__.py


from .signal import CuriosityRewardSignal

179
ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py


import tensorflow as tf
from mlagents.trainers.models import LearningModel
class CuriosityModel(object):
def __init__(
self,
policy_model: LearningModel,
encoding_size: int = 128,
learning_rate: float = 3e-4,
):
"""
Creates the curiosity model for the Curiosity reward generator.
:param policy_model: The model being used by the learning policy
:param encoding_size: The size of the encoding for the Curiosity module
:param learning_rate: The learning rate for the curiosity module
"""
self.encoding_size = encoding_size
self.policy_model = policy_model
encoded_state, encoded_next_state = self.create_curiosity_encoders()
self.create_inverse_model(encoded_state, encoded_next_state)
self.create_forward_model(encoded_state, encoded_next_state)
self.create_loss(learning_rate)
def create_curiosity_encoders(self):
"""
Creates state encoders for current and future observations.
Used for implementation of Curiosity-driven Exploration by Self-supervised Prediction
See https://arxiv.org/abs/1705.05363 for more details.
:return: current and future state encoder tensors.
"""
encoded_state_list = []
encoded_next_state_list = []
if self.policy_model.vis_obs_size > 0:
self.next_visual_in = []
visual_encoders = []
next_visual_encoders = []
for i in range(self.policy_model.vis_obs_size):
# Create input ops for next (t+1) visual observations.
next_visual_input = LearningModel.create_visual_input(
self.policy_model.brain.camera_resolutions[i],
name="next_visual_observation_" + str(i),
)
self.next_visual_in.append(next_visual_input)
# Create the encoder ops for current and next visual input.
# Note that these encoders are siamese.
encoded_visual = self.policy_model.create_visual_observation_encoder(
self.policy_model.visual_in[i],
self.encoding_size,
LearningModel.swish,
1,
"stream_{}_visual_obs_encoder".format(i),
False,
)
encoded_next_visual = self.policy_model.create_visual_observation_encoder(
self.next_visual_in[i],
self.encoding_size,
LearningModel.swish,
1,
"stream_{}_visual_obs_encoder".format(i),
True,
)
visual_encoders.append(encoded_visual)
next_visual_encoders.append(encoded_next_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
hidden_next_visual = tf.concat(next_visual_encoders, axis=1)
encoded_state_list.append(hidden_visual)
encoded_next_state_list.append(hidden_next_visual)
if self.policy_model.vec_obs_size > 0:
# Create the encoder ops for current and next vector input.
# Note that these encoders are siamese.
# Create input op for next (t+1) vector observation.
self.next_vector_in = tf.placeholder(
shape=[None, self.policy_model.vec_obs_size],
dtype=tf.float32,
name="next_vector_observation",
)
encoded_vector_obs = self.policy_model.create_vector_observation_encoder(
self.policy_model.vector_in,
self.encoding_size,
LearningModel.swish,
2,
"vector_obs_encoder",
False,
)
encoded_next_vector_obs = self.policy_model.create_vector_observation_encoder(
self.next_vector_in,
self.encoding_size,
LearningModel.swish,
2,
"vector_obs_encoder",
True,
)
encoded_state_list.append(encoded_vector_obs)
encoded_next_state_list.append(encoded_next_vector_obs)
encoded_state = tf.concat(encoded_state_list, axis=1)
encoded_next_state = tf.concat(encoded_next_state_list, axis=1)
return encoded_state, encoded_next_state
def create_inverse_model(self, encoded_state, encoded_next_state):
"""
Creates inverse model TensorFlow ops for Curiosity module.
Predicts action taken given current and future encoded states.
:param encoded_state: Tensor corresponding to encoded current state.
:param encoded_next_state: Tensor corresponding to encoded next state.
"""
combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
hidden = tf.layers.dense(combined_input, 256, activation=LearningModel.swish)
if self.policy_model.brain.vector_action_space_type == "continuous":
pred_action = tf.layers.dense(
hidden, self.policy_model.act_size[0], activation=None
)
squared_difference = tf.reduce_sum(
tf.squared_difference(pred_action, self.policy_model.selected_actions),
axis=1,
)
self.inverse_loss = tf.reduce_mean(
tf.dynamic_partition(squared_difference, self.policy_model.mask, 2)[1]
)
else:
pred_action = tf.concat(
[
tf.layers.dense(
hidden, self.policy_model.act_size[i], activation=tf.nn.softmax
)
for i in range(len(self.policy_model.act_size))
],
axis=1,
)
cross_entropy = tf.reduce_sum(
-tf.log(pred_action + 1e-10) * self.policy_model.selected_actions,
axis=1,
)
self.inverse_loss = tf.reduce_mean(
tf.dynamic_partition(cross_entropy, self.policy_model.mask, 2)[1]
)
def create_forward_model(self, encoded_state, encoded_next_state):
"""
Creates forward model TensorFlow ops for Curiosity module.
Predicts encoded future state based on encoded current state and given action.
:param encoded_state: Tensor corresponding to encoded current state.
:param encoded_next_state: Tensor corresponding to encoded next state.
"""
combined_input = tf.concat(
[encoded_state, self.policy_model.selected_actions], axis=1
)
hidden = tf.layers.dense(combined_input, 256, activation=LearningModel.swish)
pred_next_state = tf.layers.dense(
hidden,
self.encoding_size
* (
self.policy_model.vis_obs_size + int(self.policy_model.vec_obs_size > 0)
),
activation=None,
)
squared_difference = 0.5 * tf.reduce_sum(
tf.squared_difference(pred_next_state, encoded_next_state), axis=1
)
self.intrinsic_reward = squared_difference
self.forward_loss = tf.reduce_mean(
tf.dynamic_partition(squared_difference, self.policy_model.mask, 2)[1]
)
def create_loss(self, learning_rate):
"""
Creates the loss node of the model as well as the update_batch optimizer to update the model.
:param learning_rate: The learning rate for the optimizer.
"""
self.loss = 10 * (0.2 * self.forward_loss + 0.8 * self.inverse_loss)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
self.update_batch = optimizer.minimize(self.loss)
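# Worked illustration (not part of the original diff; assumes NumPy as np): the intrinsic
# reward above is the per-sample forward prediction error, and create_loss weights the
# forward and inverse heads 0.2/0.8.
#
#   pred_next = np.array([[0.2, 0.4], [0.0, 1.0]])    # hypothetical encoded predictions
#   target    = np.array([[0.0, 0.4], [0.5, 0.5]])    # hypothetical encoded next states
#   reward = 0.5 * np.sum((pred_next - target) ** 2, axis=1)  # -> [0.02, 0.25]
#   loss = 10 * (0.2 * forward_loss + 0.8 * inverse_loss)     # as in create_loss above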

178
ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py


import numpy as np
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
from mlagents.trainers.components.reward_signals.curiosity.model import CuriosityModel
from mlagents.trainers.policy import Policy
class CuriosityRewardSignal(RewardSignal):
def __init__(
self,
policy: Policy,
strength: float,
gamma: float,
encoding_size: int = 128,
learning_rate: float = 3e-4,
num_epoch: int = 3,
):
"""
Creates the Curiosity reward generator.
:param policy: The Learning Policy
:param strength: The scaling parameter for the reward. The scaled reward will be the unscaled
reward multiplied by the strength parameter.
:param gamma: The time discounting factor used for this reward.
:param encoding_size: The size of the Curiosity encoding.
:param learning_rate: The learning rate for the Curiosity module.
:param num_epoch: The number of epochs per update of the Curiosity module.
"""
super().__init__(policy, strength, gamma)
self.model = CuriosityModel(
policy.model, encoding_size=encoding_size, learning_rate=learning_rate
)
self.num_epoch = num_epoch
self.update_dict = {
"forward_loss": self.model.forward_loss,
"inverse_loss": self.model.inverse_loss,
"update": self.model.update_batch,
}
self.has_updated = False
def evaluate(self, current_info, next_info):
"""
Evaluates the reward for the agents present in current_info given the next_info
:param current_info: The current BrainInfo.
:param next_info: The BrainInfo from the next timestep.
:return: a RewardSignalResult of (scaled intrinsic reward, unscaled intrinsic reward) provided by the generator
"""
if len(current_info.agents) == 0:
return RewardSignalResult([], [])
feed_dict = {
self.policy.model.batch_size: len(next_info.vector_observations),
self.policy.model.sequence_length: 1,
}
feed_dict = self.policy.fill_eval_dict(feed_dict, brain_info=current_info)
if self.policy.use_continuous_act:
feed_dict[
self.policy.model.selected_actions
] = next_info.previous_vector_actions
else:
feed_dict[
self.policy.model.action_holder
] = next_info.previous_vector_actions
for i in range(self.policy.model.vis_obs_size):
feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
if self.policy.use_vec_obs:
feed_dict[self.model.next_vector_in] = next_info.vector_observations
if self.policy.use_recurrent:
if current_info.memories.shape[1] == 0:
current_info.memories = self.policy.make_empty_memory(
len(current_info.agents)
)
feed_dict[self.policy.model.memory_in] = current_info.memories
unscaled_reward = self.policy.sess.run(
self.model.intrinsic_reward, feed_dict=feed_dict
)
scaled_reward = np.clip(
unscaled_reward * float(self.has_updated) * self.strength, 0, 1
)
return RewardSignalResult(scaled_reward, unscaled_reward)
@classmethod
def check_config(cls, config_dict):
"""
Checks the config and throws an exception if a hyperparameter is missing. Curiosity requires strength,
gamma, and encoding_size at minimum.
"""
param_keys = ["strength", "gamma", "encoding_size"]
super().check_config(config_dict, param_keys)
def update(self, update_buffer, num_sequences):
"""
Updates Curiosity model using training buffer. Divides training buffer into mini batches and performs
gradient descent.
:param update_buffer: Update buffer from which to pull data.
:param num_sequences: Number of sequences in the update buffer.
:return: Dict of stats that should be reported to Tensorboard.
"""
forward_total, inverse_total = [], []
for _ in range(self.num_epoch):
update_buffer.shuffle()
buffer = update_buffer
for l in range(len(update_buffer["actions"]) // num_sequences):
start = l * num_sequences
end = (l + 1) * num_sequences
run_out_curio = self._update_batch(
buffer.make_mini_batch(start, end), num_sequences
)
inverse_total.append(run_out_curio["inverse_loss"])
forward_total.append(run_out_curio["forward_loss"])
update_stats = {
"Losses/Curiosity Forward Loss": np.mean(forward_total),
"Losses/Curiosity Inverse Loss": np.mean(inverse_total),
}
return update_stats
def _update_batch(self, mini_batch, num_sequences):
"""
Updates model using buffer.
:param num_sequences: Number of trajectories in batch.
:param mini_batch: Experience batch.
:return: Output from update process.
"""
feed_dict = {
self.policy.model.batch_size: num_sequences,
self.policy.model.sequence_length: self.policy.sequence_length,
self.policy.model.mask_input: mini_batch["masks"].flatten(),
self.policy.model.advantage: mini_batch["advantages"].reshape([-1, 1]),
self.policy.model.all_old_log_probs: mini_batch["action_probs"].reshape(
[-1, sum(self.policy.model.act_size)]
),
}
if self.policy.use_continuous_act:
feed_dict[self.policy.model.output_pre] = mini_batch["actions_pre"].reshape(
[-1, self.policy.model.act_size[0]]
)
feed_dict[self.policy.model.epsilon] = mini_batch[
"random_normal_epsilon"
].reshape([-1, self.policy.model.act_size[0]])
else:
feed_dict[self.policy.model.action_holder] = mini_batch["actions"].reshape(
[-1, len(self.policy.model.act_size)]
)
if self.policy.use_recurrent:
feed_dict[self.policy.model.prev_action] = mini_batch[
"prev_action"
].reshape([-1, len(self.policy.model.act_size)])
feed_dict[self.policy.model.action_masks] = mini_batch[
"action_mask"
].reshape([-1, sum(self.policy.brain.vector_action_space_size)])
if self.policy.use_vec_obs:
feed_dict[self.policy.model.vector_in] = mini_batch["vector_obs"].reshape(
[-1, self.policy.vec_obs_size]
)
feed_dict[self.model.next_vector_in] = mini_batch["next_vector_in"].reshape(
[-1, self.policy.vec_obs_size]
)
if self.policy.model.vis_obs_size > 0:
for i, _ in enumerate(self.policy.model.visual_in):
_obs = mini_batch["visual_obs%d" % i]
if self.policy.sequence_length > 1 and self.policy.use_recurrent:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.policy.model.visual_in[i]] = _obs.reshape(
[-1, _w, _h, _c]
)
else:
feed_dict[self.policy.model.visual_in[i]] = _obs
for i, _ in enumerate(self.policy.model.visual_in):
_obs = mini_batch["next_visual_obs%d" % i]
if self.policy.sequence_length > 1 and self.policy.use_recurrent:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.next_visual_in[i]] = _obs.reshape(
[-1, _w, _h, _c]
)
else:
feed_dict[self.model.next_visual_in[i]] = _obs
if self.policy.use_recurrent:
mem_in = mini_batch["memory"][:, 0, :]
feed_dict[self.policy.model.memory_in] = mem_in
self.has_updated = True
run_out = self.policy._execute_model(feed_dict, self.update_dict)
return run_out
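# Usage sketch (illustrative, not part of the original diff; assumes an existing
# PPOPolicy `policy`, consecutive BrainInfos, and a filled update buffer):
#
#   signal = CuriosityRewardSignal(policy, strength=0.02, gamma=0.99, encoding_size=256)
#   scaled, unscaled = signal.evaluate(brain_info, next_brain_info)  # per-agent rewards
#   stats = signal.update(update_buffer, num_sequences)  # after collecting a buffer
#   # stats -> {"Losses/Curiosity Forward Loss": ..., "Losses/Curiosity Inverse Loss": ...}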

1
ml-agents/mlagents/trainers/components/reward_signals/extrinsic/__init__.py


from .signal import ExtrinsicRewardSignal

42
ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py


import numpy as np
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
from mlagents.trainers.policy import Policy
class ExtrinsicRewardSignal(RewardSignal):
def __init__(self, policy: Policy, strength: float, gamma: float):
"""
The extrinsic reward generator. Returns the reward received by the environment
:param policy: The Policy object (e.g. PPOPolicy) that this Reward Signal will apply to.
:param strength: The strength of the reward. The reward's raw value will be multiplied by this value.
:param gamma: The time discounting factor used for this reward.
:return: An ExtrinsicRewardSignal object.
"""
super().__init__(policy, strength, gamma)
@classmethod
def check_config(cls, config_dict):
"""
Checks the config and throws an exception if a hyperparameter is missing. Extrinsic requires strength and gamma
at minimum.
"""
param_keys = ["strength", "gamma"]
super().check_config(config_dict, param_keys)
def evaluate(self, current_info, next_info):
"""
Evaluates the reward for the agents present in current_info given the next_info
:param current_info: The current BrainInfo.
:param next_info: The BrainInfo from the next timestep.
:return: a RewardSignalResult of (scaled extrinsic reward, unscaled extrinsic reward) provided by the generator
"""
unscaled_reward = np.array(next_info.rewards)
scaled_reward = self.strength * unscaled_reward
return RewardSignalResult(scaled_reward, unscaled_reward)
def update(self, update_buffer, num_sequences):
"""
This method does nothing, as there is nothing to update.
"""
return {}
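# Worked illustration (not part of the original diff; assumes NumPy as np): evaluate()
# is an element-wise scaling of the environment reward by `strength`.
#
#   unscaled = np.array([1.0, -2.0])   # next_info.rewards
#   scaled = 0.5 * unscaled            # with strength=0.5 -> array([ 0.5, -1.0])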

68
ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py


import logging
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.trainers.policy import Policy
from collections import namedtuple
import numpy as np
import abc
import tensorflow as tf
logger = logging.getLogger("mlagents.trainers")
RewardSignalResult = namedtuple(
"RewardSignalResult", ["scaled_reward", "unscaled_reward"]
)
class RewardSignal(abc.ABC):
def __init__(self, policy: Policy, strength: float, gamma: float):
"""
Initializes a reward signal. At minimum, you must pass in the policy it is being applied to,
the reward strength, and the gamma (discount factor).
:param policy: The Policy object (e.g. PPOPolicy) that this Reward Signal will apply to.
:param strength: The strength of the reward. The reward's raw value will be multiplied by this value.
:param gamma: The time discounting factor used for this reward.
:return: A RewardSignal object.
"""
class_name = self.__class__.__name__
short_name = class_name.replace("RewardSignal", "")
self.stat_name = f"Policy/{short_name} Reward"
self.value_name = f"Policy/{short_name} Value Estimate"
self.gamma = gamma
self.policy = policy
self.strength = strength
def evaluate(self, current_info, next_info):
"""
Evaluates the reward for the agents present in current_info given the next_info
:param current_info: The current BrainInfo.
:param next_info: The BrainInfo from the next timestep.
:return: a RewardSignalResult of (scaled reward, unscaled reward) provided by the generator
"""
return (
self.strength * np.zeros(len(current_info.agents)),
np.zeros(len(current_info.agents)),
)
def update(self, update_buffer, n_sequences):
"""
If the reward signal has an internal model (e.g. GAIL or Curiosity), update that model.
:param update_buffer: An AgentBuffer that contains the live data from which to update.
:param n_sequences: The number of sequences in the training buffer.
:return: A dict of {"Stat Name": stat} to be added to Tensorboard
"""
return {}
@classmethod
def check_config(cls, config_dict, param_keys=None):
"""
Checks the config dict and throws an error if any required hyperparameters are missing.
"""
param_keys = param_keys or []
for k in param_keys:
if k not in config_dict:
raise UnityTrainerException(
"The hyper-parameter {0} could not be found for {1}.".format(
k, cls.__name__
)
)
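# Sketch of a custom signal (illustrative, not part of the original diff; the class name
# is hypothetical and assumes NumPy as np): subclasses override evaluate()/update() and,
# if they add hyperparameters, check_config(), mirroring ExtrinsicRewardSignal above.
#
#   class ConstantRewardSignal(RewardSignal):
#       def evaluate(self, current_info, next_info):
#           unscaled = np.ones(len(current_info.agents))
#           return RewardSignalResult(self.strength * unscaled, unscaled)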

43
ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py


import logging
from typing import Any, Dict, Type
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.trainers.components.reward_signals.reward_signal import RewardSignal
from mlagents.trainers.components.reward_signals.extrinsic.signal import (
ExtrinsicRewardSignal,
)
from mlagents.trainers.components.reward_signals.curiosity.signal import (
CuriosityRewardSignal,
)
from mlagents.trainers.policy import Policy
logger = logging.getLogger("mlagents.trainers")
NAME_TO_CLASS: Dict[str, Type[RewardSignal]] = {
"extrinsic": ExtrinsicRewardSignal,
"curiosity": CuriosityRewardSignal,
}
def create_reward_signal(
policy: Policy, name: str, config_entry: Dict[str, Any]
) -> RewardSignal:
"""
Creates a reward signal class based on the name and config entry provided as a dict.
:param policy: The policy class which the reward will be applied to.
:param name: The name of the reward signal
:param config_entry: The config entries for that reward signal
:return: The instantiated reward signal
"""
rcls = NAME_TO_CLASS.get(name)
if not rcls:
raise UnityTrainerException("Unknown reward signal type {0}".format(name))
rcls.check_config(config_entry)
try:
class_inst = rcls(policy, **config_entry)
except TypeError:
raise UnityTrainerException(
"Unknown parameters given for reward signal {0}".format(name)
)
return class_inst
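# Usage sketch (illustrative, not part of the original diff): each entry under
# "reward_signals" in the trainer config is passed through this factory, e.g.:
#
#   config_entry = {"strength": 0.02, "gamma": 0.99, "encoding_size": 256}
#   signal = create_reward_signal(policy, "curiosity", config_entry)
#   # An unknown name raises UnityTrainerException("Unknown reward signal type ..."),
#   # and unexpected keys in config_entry raise "Unknown parameters given ...".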