浏览代码

Refactor Trainer and Model (#2360)

- Move common functions to trainer.py, model.pyfromppo/trainer.py, ppo/policy.pyandppo/model.py'
- Introduce RLTrainer class and move most of add_experiences and some common reward
signal code there. PPO and SAC will inherit from this, not so much BC Trainer.
- Add methods to Buffer to enable sampling, truncating, and save/loading.
- Add scoping to create encoders in model.py
/develop-gpu-test
GitHub 5 年前
当前提交
7b69bd14
共有 16 个文件被更改,包括 836 次插入653 次删除
  1. 2
      ml-agents/mlagents/trainers/bc/models.py
  2. 11
      ml-agents/mlagents/trainers/bc/offline_trainer.py
  3. 11
      ml-agents/mlagents/trainers/bc/online_trainer.py
  4. 30
      ml-agents/mlagents/trainers/bc/trainer.py
  5. 61
      ml-agents/mlagents/trainers/buffer.py
  6. 5
      ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
  7. 260
      ml-agents/mlagents/trainers/models.py
  8. 229
      ml-agents/mlagents/trainers/ppo/models.py
  9. 26
      ml-agents/mlagents/trainers/ppo/policy.py
  10. 301
      ml-agents/mlagents/trainers/ppo/trainer.py
  11. 4
      ml-agents/mlagents/trainers/tests/mock_brain.py
  12. 43
      ml-agents/mlagents/trainers/tests/test_buffer.py
  13. 171
      ml-agents/mlagents/trainers/trainer.py
  14. 1
      ml-agents/setup.py
  15. 253
      ml-agents/mlagents/trainers/rl_trainer.py
  16. 81
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py

2
ml-agents/mlagents/trainers/bc/models.py


self.action_masks = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
)
self.sample_action_float, normalized_logits = self.create_discrete_action_masking_layer(
self.sample_action_float, _, normalized_logits = self.create_discrete_action_masking_layer(
tf.concat(policy_branches, axis=1), self.action_masks, self.act_size
)
tf.identity(normalized_logits, name="action")

11
ml-agents/mlagents/trainers/bc/offline_trainer.py


"The provided demonstration is not compatible with the "
"brain being used for performance evaluation."
)
def __str__(self):
return """Hyperparameters for the Imitation Trainer of brain {0}: \n{1}""".format(
self.brain_name,
"\n".join(
[
"\t{0}:\t{1}".format(x, self.trainer_parameters[x])
for x in self.param_keys
]
),
)

11
ml-agents/mlagents/trainers/bc/online_trainer.py


int(trainer_parameters["batch_size"] / self.policy.sequence_length), 1
)
def __str__(self):
return """Hyperparameters for the Imitation Trainer of brain {0}: \n{1}""".format(
self.brain_name,
"\n".join(
[
"\t{0}:\t{1}".format(x, self.trainer_parameters[x])
for x in self.param_keys
]
),
)
def add_experiences(
self,
curr_info: AllBrainInfo,

30
ml-agents/mlagents/trainers/bc/trainer.py


self.demonstration_buffer = Buffer()
self.evaluation_buffer = Buffer()
@property
def parameters(self):
"""
Returns the trainer parameters of the trainer.
"""
return self.trainer_parameters
@property
def get_max_steps(self):
"""
Returns the maximum number of steps. Is used to know when the trainer should be stopped.
:return: The maximum number of steps of the trainer
"""
return float(self.trainer_parameters["max_steps"])
@property
def get_step(self):
"""
Returns the number of steps the trainer has performed
:return: the step count of the trainer
"""
return self.policy.get_current_step()
def increment_step(self):
"""
Increment the step count of the trainer
"""
self.policy.increment_step()
return
def add_experiences(
self,
curr_info: AllBrainInfo,

61
ml-agents/mlagents/trainers/buffer.py


import random
from collections import defaultdict
import h5py
from mlagents.envs.exception import UnityException

mini_batch[key] = self[key][start:end]
return mini_batch
def sample_mini_batch(self, batch_size, sequence_length=1):
"""
Creates a mini-batch from a random start and end.
:param batch_size: number of elements to withdraw.
:param sequence_length: Length of sequences to sample.
Number of sequences to sample will be batch_size/sequence_length.
"""
num_seq_to_sample = batch_size // sequence_length
mini_batch = Buffer.AgentBuffer()
buff_len = len(next(iter(self.values())))
num_sequences_in_buffer = buff_len // sequence_length
start_idxes = [
random.randint(0, num_sequences_in_buffer - 1) * sequence_length
for _ in range(num_seq_to_sample)
] # Sample random sequence starts
for i in start_idxes:
for key in self:
mini_batch[key].extend(self[key][i : i + sequence_length])
return mini_batch
def save_to_file(self, file_object):
"""
Saves the AgentBuffer to a file-like object.
"""
with h5py.File(file_object) as write_file:
for key, data in self.items():
write_file.create_dataset(
key, data=data, dtype="f", compression="gzip"
)
def load_from_file(self, file_object):
"""
Loads the AgentBuffer from a file-like object.
"""
with h5py.File(file_object) as read_file:
for key in list(read_file.keys()):
self[key] = Buffer.AgentBuffer.AgentBufferField()
# extend() will convert the numpy array's first dimension into list
self[key].extend(read_file[key][()])
def __init__(self):
self.update_buffer = self.AgentBuffer()
super(Buffer, self).__init__()

Resets the update buffer
"""
self.update_buffer.reset_agent()
def truncate_update_buffer(self, max_length, sequence_length=1):
"""
Truncates the update buffer to a certain length.
This can be slow for large buffers. We compensate by cutting further than we need to, so that
we're not truncating at each update. Note that we must truncate an integer number of sequence_lengths
param: max_length: The length at which to truncate the buffer.
"""
current_length = len(next(iter(self.update_buffer.values())))
# make max_length an integer number of sequence_lengths
max_length -= max_length % sequence_length
if current_length > max_length:
for _key in self.update_buffer.keys():
self.update_buffer[_key] = self.update_buffer[_key][
current_length - max_length :
]
def reset_local_buffers(self):
"""

5
ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py


param_keys = ["strength", "gamma"]
super().check_config(config_dict, param_keys)
def evaluate_batch(self, mini_batch: Dict[str, np.array]) -> RewardSignalResult:
env_rews = mini_batch["environment_rewards"]
return RewardSignalResult(self.strength * env_rews, env_rews)
def evaluate(
self, current_info: BrainInfo, next_info: BrainInfo
) -> RewardSignalResult:

260
ml-agents/mlagents/trainers/models.py


import logging
from enum import Enum
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, List
import numpy as np
import tensorflow as tf

ActivationFunction = Callable[[tf.Tensor], tf.Tensor]
EPSILON = 1e-7
class EncoderType(Enum):

:param all_logits: The concatenated unnormalized action probabilities for all branches
:param action_masks: The mask for the logits. Must be of dimension [None x total_number_of_action]
:param action_size: A list containing the number of possible actions for each branch
:return: The action output dimension [batch_size, num_branches] and the concatenated normalized logits
:return: The action output dimension [batch_size, num_branches], the concatenated
normalized probs (after softmax)
and the concatenated normalized log probs
"""
action_idx = [0] + list(np.cumsum(action_size))
branches_logits = [

for i in range(len(action_size))
]
raw_probs = [
tf.multiply(tf.nn.softmax(branches_logits[k]) + 1.0e-10, branch_masks[k])
tf.multiply(tf.nn.softmax(branches_logits[k]) + EPSILON, branch_masks[k])
for k in range(len(action_size))
]
normalized_probs = [

output = tf.concat(
[
tf.multinomial(tf.log(normalized_probs[k]), 1)
tf.multinomial(tf.log(normalized_probs[k] + EPSILON), 1)
for k in range(len(action_size))
],
axis=1,

tf.concat([normalized_probs[k] for k in range(len(action_size))], axis=1),
tf.log(normalized_probs[k] + 1.0e-10)
tf.log(normalized_probs[k] + EPSILON)
for k in range(len(action_size))
],
axis=1,

h_size: int,
num_layers: int,
vis_encode_type: EncoderType = EncoderType.SIMPLE,
stream_scopes: List[str] = None,
) -> tf.Tensor:
"""
Creates encoding stream for observations.

:param stream_scopes: List of strings (length == num_streams), which contains
the scopes for each of the streams. None if all under the same TF scope.
:return: List of encoded streams.
"""
brain = self.brain

for i in range(num_streams):
visual_encoders = []
hidden_state, hidden_visual = None, None
_scope_add = stream_scopes[i] if stream_scopes else ""
if self.vis_obs_size > 0:
if vis_encode_type == EncoderType.RESNET:
for j in range(brain.number_visual_observations):

activation_fn,
num_layers,
"main_graph_{}_encoder{}".format(i, j),
_scope_add + "main_graph_{}_encoder{}".format(i, j),
False,
)
visual_encoders.append(encoded_visual)

h_size,
activation_fn,
num_layers,
"main_graph_{}_encoder{}".format(i, j),
_scope_add + "main_graph_{}_encoder{}".format(i, j),
False,
)
visual_encoders.append(encoded_visual)

h_size,
activation_fn,
num_layers,
"main_graph_{}_encoder{}".format(i, j),
_scope_add + "main_graph_{}_encoder{}".format(i, j),
False,
)
visual_encoders.append(encoded_visual)

h_size,
activation_fn,
num_layers,
"main_graph_{}".format(i),
_scope_add + "main_graph_{}".format(i),
False,
)
if hidden_state is not None and hidden_visual is not None:

value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name))
self.value_heads[name] = value
self.value = tf.reduce_mean(list(self.value_heads.values()), 0)
def create_cc_actor_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""
Creates Continuous control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
2, h_size, num_layers, vis_encode_type
)
if self.use_recurrent:
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
_half_point = int(self.m_size / 2)
hidden_policy, memory_policy_out = self.create_recurrent_encoder(
hidden_streams[0],
self.memory_in[:, :_half_point],
self.sequence_length,
name="lstm_policy",
)
hidden_value, memory_value_out = self.create_recurrent_encoder(
hidden_streams[1],
self.memory_in[:, _half_point:],
self.sequence_length,
name="lstm_value",
)
self.memory_out = tf.concat(
[memory_policy_out, memory_value_out], axis=1, name="recurrent_out"
)
else:
hidden_policy = hidden_streams[0]
hidden_value = hidden_streams[1]
mu = tf.layers.dense(
hidden_policy,
self.act_size[0],
activation=None,
name="mu",
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01),
)
self.log_sigma_sq = tf.get_variable(
"log_sigma_squared",
[self.act_size[0]],
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
sigma_sq = tf.exp(self.log_sigma_sq)
self.epsilon = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon"
)
# Clip and scale output to ensure actions are always within [-1, 1] range.
self.output_pre = mu + tf.sqrt(sigma_sq) * self.epsilon
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
self.output = tf.identity(output_post, name="action")
self.selected_actions = tf.stop_gradient(output_post)
# Compute probability of model output.
all_probs = (
-0.5 * tf.square(tf.stop_gradient(self.output_pre) - mu) / sigma_sq
- 0.5 * tf.log(2.0 * np.pi)
- 0.5 * self.log_sigma_sq
)
self.all_log_probs = tf.identity(all_probs, name="action_probs")
self.entropy = 0.5 * tf.reduce_mean(
tf.log(2 * np.pi * np.e) + self.log_sigma_sq
)
self.create_value_heads(self.stream_names, hidden_value)
self.all_old_log_probs = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="old_probabilities"
)
# We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
self.log_probs = tf.reduce_sum(
(tf.identity(self.all_log_probs)), axis=1, keepdims=True
)
self.old_log_probs = tf.reduce_sum(
(tf.identity(self.all_old_log_probs)), axis=1, keepdims=True
)
def create_dc_actor_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""
Creates Discrete control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
1, h_size, num_layers, vis_encode_type
)
hidden = hidden_streams[0]
if self.use_recurrent:
self.prev_action = tf.placeholder(
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
)
prev_action_oh = tf.concat(
[
tf.one_hot(self.prev_action[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
hidden = tf.concat([hidden, prev_action_oh], axis=1)
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
hidden, memory_out = self.create_recurrent_encoder(
hidden, self.memory_in, self.sequence_length
)
self.memory_out = tf.identity(memory_out, name="recurrent_out")
policy_branches = []
for i, size in enumerate(self.act_size):
policy_branches.append(
tf.layers.dense(
hidden,
size,
activation=None,
use_bias=False,
name="policy_branch_" + str(i),
kernel_initializer=c_layers.variance_scaling_initializer(
factor=0.01
),
)
)
self.all_log_probs = tf.concat(
[branch for branch in policy_branches], axis=1, name="action_probs"
)
self.action_masks = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
)
output, normalized_logits = self.create_discrete_action_masking_layer(
self.all_log_probs, self.action_masks, self.act_size
)
self.output = tf.identity(output)
self.normalized_logits = tf.identity(normalized_logits, name="action")
self.create_value_heads(self.stream_names, hidden)
self.action_holder = tf.placeholder(
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder"
)
self.action_oh = tf.concat(
[
tf.one_hot(self.action_holder[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
self.selected_actions = tf.stop_gradient(self.action_oh)
self.all_old_log_probs = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="old_probabilities"
)
_, old_normalized_logits = self.create_discrete_action_masking_layer(
self.all_old_log_probs, self.action_masks, self.act_size
)
action_idx = [0] + list(np.cumsum(self.act_size))
self.entropy = tf.reduce_sum(
(
tf.stack(
[
tf.nn.softmax_cross_entropy_with_logits_v2(
labels=tf.nn.softmax(
self.all_log_probs[:, action_idx[i] : action_idx[i + 1]]
),
logits=self.all_log_probs[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
)
self.log_probs = tf.reduce_sum(
(
tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]],
logits=normalized_logits[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
keepdims=True,
)
self.old_log_probs = tf.reduce_sum(
(
tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]],
logits=old_normalized_logits[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
keepdims=True,
)

229
ml-agents/mlagents/trainers/ppo/models.py


max_step,
)
def create_cc_actor_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""
Creates Continuous control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
2, h_size, num_layers, vis_encode_type
)
if self.use_recurrent:
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
_half_point = int(self.m_size / 2)
hidden_policy, memory_policy_out = self.create_recurrent_encoder(
hidden_streams[0],
self.memory_in[:, :_half_point],
self.sequence_length,
name="lstm_policy",
)
hidden_value, memory_value_out = self.create_recurrent_encoder(
hidden_streams[1],
self.memory_in[:, _half_point:],
self.sequence_length,
name="lstm_value",
)
self.memory_out = tf.concat(
[memory_policy_out, memory_value_out], axis=1, name="recurrent_out"
)
else:
hidden_policy = hidden_streams[0]
hidden_value = hidden_streams[1]
mu = tf.layers.dense(
hidden_policy,
self.act_size[0],
activation=None,
kernel_initializer=LearningModel.scaled_init(0.01),
)
self.log_sigma_sq = tf.get_variable(
"log_sigma_squared",
[self.act_size[0]],
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
sigma_sq = tf.exp(self.log_sigma_sq)
self.epsilon = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon"
)
# Clip and scale output to ensure actions are always within [-1, 1] range.
self.output_pre = mu + tf.sqrt(sigma_sq) * self.epsilon
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
self.output = tf.identity(output_post, name="action")
self.selected_actions = tf.stop_gradient(output_post)
# Compute probability of model output.
all_probs = (
-0.5 * tf.square(tf.stop_gradient(self.output_pre) - mu) / sigma_sq
- 0.5 * tf.log(2.0 * np.pi)
- 0.5 * self.log_sigma_sq
)
self.all_log_probs = tf.identity(all_probs, name="action_probs")
self.entropy = 0.5 * tf.reduce_mean(
tf.log(2 * np.pi * np.e) + self.log_sigma_sq
)
self.create_value_heads(self.stream_names, hidden_value)
self.all_old_log_probs = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="old_probabilities"
)
# We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
self.log_probs = tf.reduce_sum(
(tf.identity(self.all_log_probs)), axis=1, keepdims=True
)
self.old_log_probs = tf.reduce_sum(
(tf.identity(self.all_old_log_probs)), axis=1, keepdims=True
)
def create_dc_actor_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""
Creates Discrete control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
1, h_size, num_layers, vis_encode_type
)
hidden = hidden_streams[0]
if self.use_recurrent:
self.prev_action = tf.placeholder(
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
)
prev_action_oh = tf.concat(
[
tf.one_hot(self.prev_action[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
hidden = tf.concat([hidden, prev_action_oh], axis=1)
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
hidden, memory_out = self.create_recurrent_encoder(
hidden, self.memory_in, self.sequence_length
)
self.memory_out = tf.identity(memory_out, name="recurrent_out")
policy_branches = []
for size in self.act_size:
policy_branches.append(
tf.layers.dense(
hidden,
size,
activation=None,
use_bias=False,
kernel_initializer=LearningModel.scaled_init(0.01),
)
)
self.all_log_probs = tf.concat(
[branch for branch in policy_branches], axis=1, name="action_probs"
)
self.action_masks = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
)
output, _, normalized_logits = self.create_discrete_action_masking_layer(
self.all_log_probs, self.action_masks, self.act_size
)
self.output = tf.identity(output)
self.normalized_logits = tf.identity(normalized_logits, name="action")
self.create_value_heads(self.stream_names, hidden)
self.action_holder = tf.placeholder(
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder"
)
self.action_oh = tf.concat(
[
tf.one_hot(self.action_holder[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
self.selected_actions = tf.stop_gradient(self.action_oh)
self.all_old_log_probs = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="old_probabilities"
)
_, _, old_normalized_logits = self.create_discrete_action_masking_layer(
self.all_old_log_probs, self.action_masks, self.act_size
)
action_idx = [0] + list(np.cumsum(self.act_size))
self.entropy = tf.reduce_sum(
(
tf.stack(
[
tf.nn.softmax_cross_entropy_with_logits_v2(
labels=tf.nn.softmax(
self.all_log_probs[:, action_idx[i] : action_idx[i + 1]]
),
logits=self.all_log_probs[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
)
self.log_probs = tf.reduce_sum(
(
tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]],
logits=normalized_logits[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
keepdims=True,
)
self.old_log_probs = tf.reduce_sum(
(
tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]],
logits=old_normalized_logits[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
keepdims=True,
)
def create_losses(
self, probs, old_probs, value_heads, entropy, beta, epsilon, lr, max_step
):

26
ml-agents/mlagents/trainers/ppo/policy.py


self.inference_dict = {
"action": self.model.output,
"log_probs": self.model.all_log_probs,
"value": self.model.value_heads,
"value": self.model.value,
"value_heads": self.model.value_heads,
"entropy": self.model.entropy,
"learning_rate": self.model.learning_rate,
}

value_estimates[k] = 0.0
return value_estimates
def get_action(self, brain_info: BrainInfo) -> ActionInfo:
"""
Decides actions given observations information, and takes them in environment.
:param brain_info: A dictionary of brain names and BrainInfo from environment.
:return: an ActionInfo containing action, memories, values and an object
to be passed to add experiences
"""
if len(brain_info.agents) == 0:
return ActionInfo([], [], [], None, None)
run_out = self.evaluate(brain_info)
mean_values = np.mean(
np.array(list(run_out.get("value").values())), axis=0
).flatten()
return ActionInfo(
action=run_out.get("action"),
memory=run_out.get("memory_out"),
text=None,
value=mean_values,
outputs=run_out,
)

301
ml-agents/mlagents/trainers/ppo/trainer.py


import logging
from collections import defaultdict
from typing import List, Any
from typing import List, Any, Dict
import numpy as np

from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy, get_devices
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.trainers.rl_trainer import RLTrainer
from mlagents.trainers.components.reward_signals import RewardSignalResult
class PPOTrainer(Trainer):
class PPOTrainer(RLTrainer):
"""The PPOTrainer is an implementation of the PPO algorithm."""
def __init__(

:param seed: The seed the model will be initialized with
:param run_id: The identifier of the current run
"""
super().__init__(brain, trainer_parameters, training, run_id, reward_buff_cap)
super(PPOTrainer, self).__init__(
brain, trainer_parameters, training, run_id, reward_buff_cap
)
self.param_keys = [
"batch_size",
"beta",

]
self.check_param_keys()
# Make sure we have at least one reward_signal
if not self.trainer_parameters["reward_signals"]:
raise UnityTrainerException(
"No reward signals were defined. At least one must be used with {}.".format(
self.__class__.__name__
)
)
self.step = 0
if multi_gpu and len(get_devices()) > 1:
self.policy = MultiGpuPPOPolicy(
seed, brain, trainer_parameters, self.is_training, load

seed, brain, trainer_parameters, self.is_training, load
)
stats = defaultdict(list)
# collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
# used for reporting only. We always want to report the environment reward to Tensorboard, regardless
# of what reward signals are actually present.
self.collected_rewards = {"environment": {}}
self.stats = stats
self.training_buffer = Buffer()
self.episode_steps = {}
def __str__(self):
return """Hyperparameters for the {0} of brain {1}: \n{2}""".format(
self.__class__.__name__,
self.brain_name,
self.dict_to_str(self.trainer_parameters, 0),
)
@property
def parameters(self):
"""
Returns the trainer parameters of the trainer.
"""
return self.trainer_parameters
@property
def get_max_steps(self):
"""
Returns the maximum number of steps. Is used to know when the trainer should be stopped.
:return: The maximum number of steps of the trainer
"""
return float(self.trainer_parameters["max_steps"])
@property
def get_step(self):
"""
Returns the number of steps the trainer has performed
:return: the step count of the trainer
"""
return self.step
def increment_step(self, n_steps: int) -> None:
"""
Increment the step count of the trainer
:param n_steps: number of steps to increment the step count by
"""
self.step = self.policy.increment_step(n_steps)
def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
"""
Constructs a BrainInfo which contains the most recent previous experiences for all agents
which correspond to the agents in a provided next_info.
:BrainInfo next_info: A t+1 BrainInfo.
:return: curr_info: Reconstructed BrainInfo to match agents of next_info.
"""
visual_observations: List[List[Any]] = [
[]
] # TODO add types to brain.py methods
vector_observations = []
text_observations = []
memories = []
rewards = []
local_dones = []
max_reacheds = []
agents = []
prev_vector_actions = []
prev_text_actions = []
action_masks = []
for agent_id in next_info.agents:
agent_brain_info = self.training_buffer[agent_id].last_brain_info
if agent_brain_info is None:
agent_brain_info = next_info
agent_index = agent_brain_info.agents.index(agent_id)
for i in range(len(next_info.visual_observations)):
visual_observations[i].append(
agent_brain_info.visual_observations[i][agent_index]
)
vector_observations.append(
agent_brain_info.vector_observations[agent_index]
)
text_observations.append(agent_brain_info.text_observations[agent_index])
if self.policy.use_recurrent:
if len(agent_brain_info.memories) > 0:
memories.append(agent_brain_info.memories[agent_index])
else:
memories.append(self.policy.make_empty_memory(1))
rewards.append(agent_brain_info.rewards[agent_index])
local_dones.append(agent_brain_info.local_done[agent_index])
max_reacheds.append(agent_brain_info.max_reached[agent_index])
agents.append(agent_brain_info.agents[agent_index])
prev_vector_actions.append(
agent_brain_info.previous_vector_actions[agent_index]
)
prev_text_actions.append(
agent_brain_info.previous_text_actions[agent_index]
)
action_masks.append(agent_brain_info.action_masks[agent_index])
if self.policy.use_recurrent:
memories = np.vstack(memories)
curr_info = BrainInfo(
visual_observations,
vector_observations,
text_observations,
memories,
rewards,
agents,
local_dones,
prev_vector_actions,
prev_text_actions,
max_reacheds,
action_masks,
)
return curr_info
def add_experiences(
self,
curr_all_info: AllBrainInfo,
next_all_info: AllBrainInfo,
take_action_outputs: ActionInfoOutputs,
) -> None:
"""
Adds experiences to each agent's experience history.
:param curr_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param next_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param take_action_outputs: The outputs of the Policy's get_action method.
"""
self.trainer_metrics.start_experience_collection_timer()
if take_action_outputs:
self.stats["Policy/Entropy"].append(take_action_outputs["entropy"].mean())
self.stats["Policy/Learning Rate"].append(
take_action_outputs["learning_rate"]
)
for name, signal in self.policy.reward_signals.items():
self.stats[signal.value_name].append(
np.mean(take_action_outputs["value"][name])
)
curr_info = curr_all_info[self.brain_name]
next_info = next_all_info[self.brain_name]
for agent_id in curr_info.agents:
self.training_buffer[agent_id].last_brain_info = curr_info
self.training_buffer[
agent_id
].last_take_action_outputs = take_action_outputs
if curr_info.agents != next_info.agents:
curr_to_use = self.construct_curr_info(next_info)
else:
curr_to_use = curr_info
tmp_rewards_dict = {}
for name, signal in self.policy.reward_signals.items():
tmp_rewards_dict[name] = signal.evaluate(curr_to_use, next_info)
for agent_id in next_info.agents:
stored_info = self.training_buffer[agent_id].last_brain_info
stored_take_action_outputs = self.training_buffer[
agent_id
].last_take_action_outputs
if stored_info is not None:
idx = stored_info.agents.index(agent_id)
next_idx = next_info.agents.index(agent_id)
if not stored_info.local_done[idx]:
for i, _ in enumerate(stored_info.visual_observations):
self.training_buffer[agent_id]["visual_obs%d" % i].append(
stored_info.visual_observations[i][idx]
)
self.training_buffer[agent_id]["next_visual_obs%d" % i].append(
next_info.visual_observations[i][next_idx]
)
if self.policy.use_vec_obs:
self.training_buffer[agent_id]["vector_obs"].append(
stored_info.vector_observations[idx]
)
self.training_buffer[agent_id]["next_vector_in"].append(
next_info.vector_observations[next_idx]
)
if self.policy.use_recurrent:
if stored_info.memories.shape[1] == 0:
stored_info.memories = np.zeros(
(len(stored_info.agents), self.policy.m_size)
)
self.training_buffer[agent_id]["memory"].append(
stored_info.memories[idx]
)
actions = stored_take_action_outputs["action"]
if self.policy.use_continuous_act:
actions_pre = stored_take_action_outputs["pre_action"]
self.training_buffer[agent_id]["actions_pre"].append(
actions_pre[idx]
)
epsilons = stored_take_action_outputs["random_normal_epsilon"]
self.training_buffer[agent_id]["random_normal_epsilon"].append(
epsilons[idx]
)
else:
self.training_buffer[agent_id]["action_mask"].append(
stored_info.action_masks[idx], padding_value=1
)
a_dist = stored_take_action_outputs["log_probs"]
# value is a dictionary from name of reward to value estimate of the value head
value = stored_take_action_outputs["value"]
self.training_buffer[agent_id]["actions"].append(actions[idx])
self.training_buffer[agent_id]["prev_action"].append(
stored_info.previous_vector_actions[idx]
)
self.training_buffer[agent_id]["masks"].append(1.0)
self.training_buffer[agent_id]["done"].append(
next_info.local_done[next_idx]
)
for name, reward_result in tmp_rewards_dict.items():
# 0 because we use the scaled reward to train the agent
self.training_buffer[agent_id][
"{}_rewards".format(name)
].append(reward_result.scaled_reward[next_idx])
self.training_buffer[agent_id][
"{}_value_estimates".format(name)
].append(value[name][idx][0])
self.training_buffer[agent_id]["action_probs"].append(a_dist[idx])
for name, rewards in self.collected_rewards.items():
if agent_id not in rewards:
rewards[agent_id] = 0
if name == "environment":
# Report the reward from the environment
rewards[agent_id] += np.array(next_info.rewards)[next_idx]
else:
# Report the reward signals
rewards[agent_id] += tmp_rewards_dict[name].scaled_reward[
next_idx
]
if not next_info.local_done[next_idx]:
if agent_id not in self.episode_steps:
self.episode_steps[agent_id] = 0
self.episode_steps[agent_id] += 1
self.trainer_metrics.end_experience_collection_timer()
def process_experiences(
self, current_info: AllBrainInfo, new_info: AllBrainInfo
) -> None:

self.policy.reward_signals[name].stat_name
].append(rewards.get(agent_id, 0))
rewards[agent_id] = 0
def add_policy_outputs(
self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
) -> None:
"""
Takes the output of the last action and store it into the training buffer.
"""
actions = take_action_outputs["action"]
if self.policy.use_continuous_act:
actions_pre = take_action_outputs["pre_action"]
self.training_buffer[agent_id]["actions_pre"].append(actions_pre[agent_idx])
epsilons = take_action_outputs["random_normal_epsilon"]
self.training_buffer[agent_id]["random_normal_epsilon"].append(
epsilons[agent_idx]
)
a_dist = take_action_outputs["log_probs"]
# value is a dictionary from name of reward to value estimate of the value head
self.training_buffer[agent_id]["actions"].append(actions[agent_idx])
self.training_buffer[agent_id]["action_probs"].append(a_dist[agent_idx])
def add_rewards_outputs(
self,
value: Dict[str, Any],
rewards_dict: Dict[str, RewardSignalResult],
agent_id: str,
agent_idx: int,
agent_next_idx: int,
) -> None:
"""
Takes the value output of the last action and store it into the training buffer.
"""
for name, reward_result in rewards_dict.items():
# 0 because we use the scaled reward to train the agent
self.training_buffer[agent_id]["{}_rewards".format(name)].append(
reward_result.scaled_reward[agent_idx]
)
self.training_buffer[agent_id]["{}_value_estimates".format(name)].append(
value[name][agent_next_idx][0]
)
def end_episode(self):
"""

4
ml-agents/mlagents/trainers/tests/mock_brain.py


camrez = {"blackAndWhite": False, "height": 84, "width": 84}
mock_brain.return_value.camera_resolutions = [camrez] * number_visual_observations
mock_brain.return_value.vector_action_space_size = vector_action_space_size
mock_brain.return_value.brain_name = "MockBrain"
return mock_brain()

mock_braininfo.return_value.rewards = num_agents * [1.0]
mock_braininfo.return_value.local_done = num_agents * [False]
mock_braininfo.return_value.text_observations = num_agents * [""]
mock_braininfo.return_value.previous_text_actions = num_agents * [""]
mock_braininfo.return_value.max_reached = num_agents * [100]
mock_braininfo.return_value.action_masks = num_agents * [num_vector_acts * [1.0]]
mock_braininfo.return_value.agents = range(0, num_agents)
return mock_braininfo()

43
ml-agents/mlagents/trainers/tests/test_buffer.py


assert la[i] == lb[i]
def test_buffer():
def construct_fake_buffer():
b = Buffer()
for fake_agent_id in range(4):
for step in range(9):

100 * fake_agent_id + 10 * step + 5,
]
)
return b
def test_buffer():
b = construct_fake_buffer()
a = b[1]["vector_observation"].get_batch(
batch_size=2, training_length=1, sequential=True
)

c = b.update_buffer.make_mini_batch(start=0, end=1)
assert c.keys() == b.update_buffer.keys()
assert np.array(c["action"]).shape == (1, 2)
def fakerandint(values):
return 19
def test_buffer_sample():
b = construct_fake_buffer()
b.append_update_buffer(3, batch_size=None, training_length=2)
b.append_update_buffer(2, batch_size=None, training_length=2)
# Test non-LSTM
mb = b.update_buffer.sample_mini_batch(batch_size=4, sequence_length=1)
assert mb.keys() == b.update_buffer.keys()
assert np.array(mb["action"]).shape == (4, 2)
# Test LSTM
# We need to check if we ever get a breaking start - this will maximize the probability
mb = b.update_buffer.sample_mini_batch(batch_size=20, sequence_length=19)
assert mb.keys() == b.update_buffer.keys()
# Should only return one sequence
assert np.array(mb["action"]).shape == (19, 2)
def test_buffer_truncate():
b = construct_fake_buffer()
b.append_update_buffer(3, batch_size=None, training_length=2)
b.append_update_buffer(2, batch_size=None, training_length=2)
# Test non-LSTM
b.truncate_update_buffer(2)
assert len(b.update_buffer["action"]) == 2
b.append_update_buffer(3, batch_size=None, training_length=2)
b.append_update_buffer(2, batch_size=None, training_length=2)
# Test LSTM, truncate should be some multiple of sequence_length
b.truncate_update_buffer(4, sequence_length=3)
assert len(b.update_buffer["action"]) == 3

171
ml-agents/mlagents/trainers/trainer.py


# # Unity ML-Agents Toolkit
import logging
from typing import Dict, List, Deque, Any
from collections import deque
from collections import deque, defaultdict
from mlagents.envs import UnityException, AllBrainInfo, ActionInfoOutputs
from mlagents.envs import UnityException, AllBrainInfo, ActionInfoOutputs, BrainInfo
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.tf_policy import Policy
from mlagents.envs import BrainParameters
LOGGER = logging.getLogger("mlagents.trainers")

class Trainer(object):
"""This class is the base class for the mlagents.envs.trainers"""
def __init__(self, brain, trainer_parameters, training, run_id, reward_buff_cap=1):
def __init__(
self,
brain: BrainParameters,
trainer_parameters: dict,
training: bool,
run_id: int,
reward_buff_cap: int = 1,
):
"""
Responsible for collecting experiences and training a neural network model.
:BrainParameters brain: Brain to be trained.

:int reward_buff_cap:
self.param_keys = []
self.param_keys: List[str] = []
self.brain_name = brain.brain_name
self.run_id = run_id
self.trainer_parameters = trainer_parameters

self.cumulative_returns_since_policy_update = []
self.cumulative_returns_since_policy_update: List[float] = []
self.stats = {}
self.stats: Dict[str, List] = defaultdict(list)
self.policy = None
self._reward_buffer = deque(maxlen=reward_buff_cap)
def __str__(self):
return """{} Trainer""".format(self.__class__)
self._reward_buffer: Deque[float] = deque(maxlen=reward_buff_cap)
self.policy: Policy = None
def check_param_keys(self):
for k in self.param_keys:

"brain {2}.".format(k, self.__class__, self.brain_name)
)
def dict_to_str(self, param_dict, num_tabs):
def dict_to_str(self, param_dict: Dict[str, Any], num_tabs: int) -> str:
"""
Takes a parameter dictionary and converts it to a human-readable string.
Recurses if there are multiple levels of dict. Used to print out hyperaparameters.

if not isinstance(param_dict, dict):
return param_dict
return str(param_dict)
else:
append_newline = "\n" if num_tabs > 0 else ""
return append_newline + "\n".join(

]
)
@property
def parameters(self):
"""
Returns the trainer parameters of the trainer.
"""
raise UnityTrainerException("The parameters property was not implemented.")
def __str__(self) -> str:
return """Hyperparameters for the {0} of brain {1}: \n{2}""".format(
self.__class__.__name__,
self.brain_name,
self.dict_to_str(self.trainer_parameters, 0),
)
def graph_scope(self):
def parameters(self) -> Dict[str, Any]:
Returns the graph scope of the trainer.
Returns the trainer parameters of the trainer.
raise UnityTrainerException("The graph_scope property was not implemented.")
return self.trainer_parameters
def get_max_steps(self):
def get_max_steps(self) -> float:
raise UnityTrainerException("The get_max_steps property was not implemented.")
return float(self.trainer_parameters["max_steps"])
def get_step(self):
def get_step(self) -> int:
Returns the number of training steps the trainer has performed
Returns the number of steps the trainer has performed
raise UnityTrainerException("The get_step property was not implemented.")
return self.step
def reward_buffer(self):
def reward_buffer(self) -> Deque[float]:
"""
Returns the reward buffer. The reward buffer contains the cumulative
rewards of the most recent episodes completed by agents using this

def increment_step(self, n_steps: int) -> None:
"""
Increment the step count of the trainer
"""
raise UnityTrainerException("The increment_step method was not implemented.")
def add_experiences(
self,
curr_info: AllBrainInfo,
next_info: AllBrainInfo,
take_action_outputs: ActionInfoOutputs,
) -> None:
:param n_steps: number of steps to increment the step count by
Adds experiences to each agent's experience history.
:param curr_info: Current AllBrainInfo.
:param next_info: Next AllBrainInfo.
:param take_action_outputs: The outputs of the take action method.
"""
raise UnityTrainerException("The add_experiences method was not implemented.")
self.step = self.policy.increment_step(n_steps)
def process_experiences(
self, current_info: AllBrainInfo, next_info: AllBrainInfo
) -> None:
"""
Checks agent histories for processing condition, and processes them as necessary.
Processing involves calculating value and advantage targets for model updating step.
:param current_info: Dictionary of all current-step brains and corresponding BrainInfo.
:param next_info: Dictionary of all next-step brains and corresponding BrainInfo.
"""
raise UnityTrainerException(
"The process_experiences method was not implemented."
)
def end_episode(self):
"""
A signal that the Episode has ended. The buffer must be reset.
Get only called when the academy resets.
"""
raise UnityTrainerException("The end_episode method was not implemented.")
def is_ready_update(self):
"""
Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to wether or not update_model() can be run
"""
raise UnityTrainerException("The is_ready_update method was not implemented.")
def update_policy(self):
"""
Uses demonstration_buffer to update model.
"""
raise UnityTrainerException("The update_model method was not implemented.")
def save_model(self):
def save_model(self) -> None:
def export_model(self):
def export_model(self) -> None:
def write_training_metrics(self):
def write_training_metrics(self) -> None:
"""
Write training metrics to a CSV file
:return:

self.summary_writer.add_summary(summary, step)
self.summary_writer.flush()
def write_tensorboard_text(self, key, input_dict):
def write_tensorboard_text(self, key: str, input_dict: Dict[str, Any]) -> None:
"""
Saves text to Tensorboard.
Note: Only works on tensorflow r1.2 or above.

"Cannot write text summary for Tensorboard. Tensorflow version must be r1.2 or above."
)
pass
def add_experiences(
self,
curr_all_info: AllBrainInfo,
next_all_info: AllBrainInfo,
take_action_outputs: ActionInfoOutputs,
) -> None:
"""
Adds experiences to each agent's experience history.
:param curr_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param next_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param take_action_outputs: The outputs of the Policy's get_action method.
"""
raise UnityTrainerException(
"The process_experiences method was not implemented."
)
def process_experiences(
self, current_info: AllBrainInfo, next_info: AllBrainInfo
) -> None:
"""
Checks agent histories for processing condition, and processes them as necessary.
Processing involves calculating value and advantage targets for model updating step.
:param current_info: Dictionary of all current-step brains and corresponding BrainInfo.
:param next_info: Dictionary of all next-step brains and corresponding BrainInfo.
"""
raise UnityTrainerException(
"The process_experiences method was not implemented."
)
def end_episode(self):
"""
A signal that the Episode has ended. The buffer must be reset.
Get only called when the academy resets.
"""
raise UnityTrainerException("The end_episode method was not implemented.")
def is_ready_update(self):
"""
Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to wether or not update_model() can be run
"""
raise UnityTrainerException("The is_ready_update method was not implemented.")
def update_policy(self):
"""
Uses demonstration_buffer to update model.
"""
raise UnityTrainerException("The update_model method was not implemented.")

1
ml-agents/setup.py


"pyyaml",
"protobuf>=3.6,<3.7",
"grpcio>=1.11.0,<1.12.0",
"h5py==2.9.0",
'pypiwin32==223;platform_system=="Windows"',
],
python_requires=">=3.6,<3.7",

253
ml-agents/mlagents/trainers/rl_trainer.py


# # Unity ML-Agents Toolkit
import logging
from typing import Dict, List, Deque, Any
import os
import tensorflow as tf
import numpy as np
from collections import deque, defaultdict
from mlagents.envs import UnityException, AllBrainInfo, ActionInfoOutputs, BrainInfo
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.tf_policy import Policy
from mlagents.trainers.trainer import Trainer, UnityTrainerException
from mlagents.envs import BrainParameters
LOGGER = logging.getLogger("mlagents.trainers")
class RLTrainer(Trainer):
"""
This class is the base class for trainers that use Reward Signals.
Contains methods for adding BrainInfos to the Buffer.
"""
def __init__(self, *args, **kwargs):
super(RLTrainer, self).__init__(*args, **kwargs)
self.step = 0
# Make sure we have at least one reward_signal
if not self.trainer_parameters["reward_signals"]:
raise UnityTrainerException(
"No reward signals were defined. At least one must be used with {}.".format(
self.__class__.__name__
)
)
# collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
# used for reporting only. We always want to report the environment reward to Tensorboard, regardless
# of what reward signals are actually present.
self.collected_rewards = {"environment": {}}
self.training_buffer = Buffer()
self.episode_steps = {}
def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
"""
Constructs a BrainInfo which contains the most recent previous experiences for all agents
which correspond to the agents in a provided next_info.
:BrainInfo next_info: A t+1 BrainInfo.
:return: curr_info: Reconstructed BrainInfo to match agents of next_info.
"""
visual_observations: List[List[Any]] = [
[]
] # TODO add types to brain.py methods
vector_observations = []
text_observations = []
memories = []
rewards = []
local_dones = []
max_reacheds = []
agents = []
prev_vector_actions = []
prev_text_actions = []
action_masks = []
for agent_id in next_info.agents:
agent_brain_info = self.training_buffer[agent_id].last_brain_info
if agent_brain_info is None:
agent_brain_info = next_info
agent_index = agent_brain_info.agents.index(agent_id)
for i in range(len(next_info.visual_observations)):
visual_observations[i].append(
agent_brain_info.visual_observations[i][agent_index]
)
vector_observations.append(
agent_brain_info.vector_observations[agent_index]
)
text_observations.append(agent_brain_info.text_observations[agent_index])
if self.policy.use_recurrent:
if len(agent_brain_info.memories) > 0:
memories.append(agent_brain_info.memories[agent_index])
else:
memories.append(self.policy.make_empty_memory(1))
rewards.append(agent_brain_info.rewards[agent_index])
local_dones.append(agent_brain_info.local_done[agent_index])
max_reacheds.append(agent_brain_info.max_reached[agent_index])
agents.append(agent_brain_info.agents[agent_index])
prev_vector_actions.append(
agent_brain_info.previous_vector_actions[agent_index]
)
prev_text_actions.append(
agent_brain_info.previous_text_actions[agent_index]
)
action_masks.append(agent_brain_info.action_masks[agent_index])
if self.policy.use_recurrent:
memories = np.vstack(memories)
curr_info = BrainInfo(
visual_observations,
vector_observations,
text_observations,
memories,
rewards,
agents,
local_dones,
prev_vector_actions,
prev_text_actions,
max_reacheds,
action_masks,
)
return curr_info
def add_experiences(
self,
curr_all_info: AllBrainInfo,
next_all_info: AllBrainInfo,
take_action_outputs: ActionInfoOutputs,
) -> None:
"""
Adds experiences to each agent's experience history.
:param curr_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param next_all_info: Dictionary of all current brains and corresponding BrainInfo.
:param take_action_outputs: The outputs of the Policy's get_action method.
"""
self.trainer_metrics.start_experience_collection_timer()
if take_action_outputs:
self.stats["Policy/Entropy"].append(take_action_outputs["entropy"].mean())
self.stats["Policy/Learning Rate"].append(
take_action_outputs["learning_rate"]
)
for name, signal in self.policy.reward_signals.items():
self.stats[signal.value_name].append(
np.mean(take_action_outputs["value_heads"][name])
)
curr_info = curr_all_info[self.brain_name]
next_info = next_all_info[self.brain_name]
for agent_id in curr_info.agents:
self.training_buffer[agent_id].last_brain_info = curr_info
self.training_buffer[
agent_id
].last_take_action_outputs = take_action_outputs
if curr_info.agents != next_info.agents:
curr_to_use = self.construct_curr_info(next_info)
else:
curr_to_use = curr_info
tmp_rewards_dict = {}
for name, signal in self.policy.reward_signals.items():
tmp_rewards_dict[name] = signal.evaluate(curr_to_use, next_info)
for agent_id in next_info.agents:
stored_info = self.training_buffer[agent_id].last_brain_info
stored_take_action_outputs = self.training_buffer[
agent_id
].last_take_action_outputs
if stored_info is not None:
idx = stored_info.agents.index(agent_id)
next_idx = next_info.agents.index(agent_id)
if not stored_info.local_done[idx]:
for i, _ in enumerate(stored_info.visual_observations):
self.training_buffer[agent_id]["visual_obs%d" % i].append(
stored_info.visual_observations[i][idx]
)
self.training_buffer[agent_id]["next_visual_obs%d" % i].append(
next_info.visual_observations[i][next_idx]
)
if self.policy.use_vec_obs:
self.training_buffer[agent_id]["vector_obs"].append(
stored_info.vector_observations[idx]
)
self.training_buffer[agent_id]["next_vector_in"].append(
next_info.vector_observations[next_idx]
)
if self.policy.use_recurrent:
if stored_info.memories.shape[1] == 0:
stored_info.memories = np.zeros(
(len(stored_info.agents), self.policy.m_size)
)
self.training_buffer[agent_id]["memory"].append(
stored_info.memories[idx]
)
self.training_buffer[agent_id]["masks"].append(1.0)
self.training_buffer[agent_id]["done"].append(
next_info.local_done[next_idx]
)
# Add the outputs of the last eval
self.add_policy_outputs(stored_take_action_outputs, agent_id, idx)
# Store action masks if neccessary
if not self.policy.use_continuous_act:
self.training_buffer[agent_id]["action_mask"].append(
stored_info.action_masks[idx], padding_value=1
)
self.training_buffer[agent_id]["prev_action"].append(
stored_info.previous_vector_actions[idx]
)
values = stored_take_action_outputs["value_heads"]
# Add the value outputs if needed
self.add_rewards_outputs(
values, tmp_rewards_dict, agent_id, idx, next_idx
)
for name, rewards in self.collected_rewards.items():
if agent_id not in rewards:
rewards[agent_id] = 0
if name == "environment":
# Report the reward from the environment
rewards[agent_id] += np.array(next_info.rewards)[next_idx]
else:
# Report the reward signals
rewards[agent_id] += tmp_rewards_dict[name].scaled_reward[
next_idx
]
if not next_info.local_done[next_idx]:
if agent_id not in self.episode_steps:
self.episode_steps[agent_id] = 0
self.episode_steps[agent_id] += 1
self.trainer_metrics.end_experience_collection_timer()
def add_policy_outputs(
self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
) -> None:
"""
Takes the output of the last action and store it into the training buffer.
We break this out from add_experiences since it is very highly dependent
on the type of trainer.
:param take_action_outputs: The outputs of the Policy's get_action method.
:param agent_id: the Agent we're adding to.
:param agent_idx: the index of the Agent agent_id
"""
raise UnityTrainerException(
"The process_experiences method was not implemented."
)
def add_rewards_outputs(
self,
value: Dict[str, Any],
rewards_dict: Dict[str, float],
agent_id: str,
agent_idx: int,
agent_next_idx: int,
) -> None:
"""
Takes the value and evaluated rewards output of the last action and store it
into the training buffer. We break this out from add_experiences since it is very
highly dependent on the type of trainer.
:param take_action_outputs: The outputs of the Policy's get_action method.
:param rewards_dict: Dict of rewards after evaluation
:param agent_id: the Agent we're adding to.
:param agent_idx: the index of the Agent agent_id in the current brain info
:param agent_next_idx: the index of the Agent agent_id in the next brain info
"""
raise UnityTrainerException(
"The process_experiences method was not implemented."
)

81
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


import unittest.mock as mock
import pytest
import yaml
import mlagents.trainers.tests.mock_brain as mb
import numpy as np
from mlagents.trainers.rl_trainer import RLTrainer
@pytest.fixture
def dummy_config():
return yaml.safe_load(
"""
summary_path: "test/"
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)
def create_mock_brain():
mock_brain = mb.create_mock_brainparams(
vector_action_space_type="continuous",
vector_action_space_size=[2],
vector_observation_space_size=8,
number_visual_observations=1,
)
return mock_brain
def create_rl_trainer():
mock_brainparams = create_mock_brain()
trainer = RLTrainer(mock_brainparams, dummy_config(), True, 0)
return trainer
def create_mock_all_brain_info(brain_info):
return {"MockBrain": brain_info}
def create_mock_policy():
mock_policy = mock.Mock()
mock_policy.reward_signals = {}
return mock_policy
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_policy_outputs")
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_rewards_outputs")
def test_rl_trainer(add_policy_outputs, add_rewards_outputs):
trainer = create_rl_trainer()
trainer.policy = create_mock_policy()
fake_action_outputs = {
"action": [0.1, 0.1],
"value_heads": {},
"entropy": np.array([1.0]),
"learning_rate": 1.0,
}
mock_braininfo = mb.create_mock_braininfo(
num_agents=2,
num_vector_observations=8,
num_vector_acts=2,
num_vis_observations=1,
)
trainer.add_experiences(
create_mock_all_brain_info(mock_braininfo),
create_mock_all_brain_info(mock_braininfo),
fake_action_outputs,
)
# Remove one of the agents
next_mock_braininfo = mb.create_mock_braininfo(
num_agents=1,
num_vector_observations=8,
num_vector_acts=2,
num_vis_observations=1,
)
brain_info = trainer.construct_curr_info(next_mock_braininfo)
# assert construct_curr_info worked properly
assert len(brain_info.agents) == 1
正在加载...
取消
保存