
initialize

/develop/model-transfer
yanchaosun, 4 years ago
Current commit: 1b86b3ae
A total of 4 files changed, with 642 additions and 0 deletions
  1. ml-agents/mlagents/trainers/ppo_transfer/__init__.py (+0)
  2. ml-agents/mlagents/trainers/ppo_transfer/optimizer.py (+366)
  3. ml-agents/mlagents/trainers/ppo_transfer/trainer.py (+276)

ml-agents/mlagents/trainers/ppo_transfer/__init__.py (+0)

ml-agents/mlagents/trainers/ppo_transfer/optimizer.py (+366)


from typing import Optional, Any, Dict, cast

import numpy as np
from mlagents.tf_utils import tf
from mlagents_envs.timers import timed
from mlagents.trainers.models import ModelUtils, EncoderType
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.settings import TrainerSettings, PPOSettings


class PPOOptimizer(TFOptimizer):
    def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings):
        """
        Takes a Policy and a Dict of trainer parameters and creates an Optimizer around the policy.
        The PPO optimizer has a value estimator and a loss function.
        :param policy: A TFPolicy object that will be updated by this PPO Optimizer.
        :param trainer_params: Trainer parameters dictionary that specifies the properties of the trainer.
        """
        # Create the graph here to give more granular control of the TF graph to the Optimizer.
        policy.create_tf_graph()

        with policy.graph.as_default():
            with tf.variable_scope("optimizer/"):
                super().__init__(policy, trainer_params)
                hyperparameters: PPOSettings = cast(
                    PPOSettings, trainer_params.hyperparameters
                )
                lr = float(hyperparameters.learning_rate)
                self._schedule = hyperparameters.learning_rate_schedule
                epsilon = float(hyperparameters.epsilon)
                beta = float(hyperparameters.beta)
                max_step = float(trainer_params.max_steps)

                policy_network_settings = policy.network_settings
                h_size = int(policy_network_settings.hidden_units)
                num_layers = policy_network_settings.num_layers
                vis_encode_type = policy_network_settings.vis_encode_type
                self.burn_in_ratio = 0.0

                self.stream_names = list(self.reward_signals.keys())

                self.tf_optimizer: Optional[tf.train.AdamOptimizer] = None
                self.grads = None
                self.update_batch: Optional[tf.Operation] = None

                self.stats_name_to_update_name = {
                    "Losses/Value Loss": "value_loss",
                    "Losses/Policy Loss": "policy_loss",
                    "Policy/Learning Rate": "learning_rate",
                    "Policy/Epsilon": "decay_epsilon",
                    "Policy/Beta": "decay_beta",
                }
                if self.policy.use_recurrent:
                    self.m_size = self.policy.m_size
                    self.memory_in = tf.placeholder(
                        shape=[None, self.m_size],
                        dtype=tf.float32,
                        name="recurrent_value_in",
                    )

                if num_layers < 1:
                    num_layers = 1

                if policy.use_continuous_act:
                    self._create_cc_critic(h_size, num_layers, vis_encode_type)
                else:
                    self._create_dc_critic(h_size, num_layers, vis_encode_type)

                self.learning_rate = ModelUtils.create_schedule(
                    self._schedule,
                    lr,
                    self.policy.global_step,
                    int(max_step),
                    min_value=1e-10,
                )
                self._create_losses(
                    self.policy.total_log_probs,
                    self.old_log_probs,
                    self.value_heads,
                    self.policy.entropy,
                    beta,
                    epsilon,
                    lr,
                    max_step,
                )
                self._create_ppo_optimizer_ops()

                self.update_dict.update(
                    {
                        "value_loss": self.value_loss,
                        "policy_loss": self.abs_policy_loss,
                        "update_batch": self.update_batch,
                        "learning_rate": self.learning_rate,
                        "decay_epsilon": self.decay_epsilon,
                        "decay_beta": self.decay_beta,
                    }
                )

        self.policy.initialize_or_load()

    def _create_cc_critic(
        self, h_size: int, num_layers: int, vis_encode_type: EncoderType
    ) -> None:
        """
        Creates Continuous control critic (value) network.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param vis_encode_type: The type of visual encoder to use.
        """
        hidden_stream = ModelUtils.create_observation_streams(
            self.policy.visual_in,
            self.policy.processed_vector_in,
            1,
            h_size,
            num_layers,
            vis_encode_type,
        )[0]

        if self.policy.use_recurrent:
            hidden_value, memory_value_out = ModelUtils.create_recurrent_encoder(
                hidden_stream,
                self.memory_in,
                self.policy.sequence_length_ph,
                name="lstm_value",
            )
            self.memory_out = memory_value_out
        else:
            hidden_value = hidden_stream

        self.value_heads, self.value = ModelUtils.create_value_heads(
            self.stream_names, hidden_value
        )
        self.all_old_log_probs = tf.placeholder(
            shape=[None, sum(self.policy.act_size)],
            dtype=tf.float32,
            name="old_probabilities",
        )

        self.old_log_probs = tf.reduce_sum(
            (tf.identity(self.all_old_log_probs)), axis=1, keepdims=True
        )

    def _create_dc_critic(
        self, h_size: int, num_layers: int, vis_encode_type: EncoderType
    ) -> None:
        """
        Creates Discrete control critic (value) network.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param vis_encode_type: The type of visual encoder to use.
        """
        hidden_stream = ModelUtils.create_observation_streams(
            self.policy.visual_in,
            self.policy.processed_vector_in,
            1,
            h_size,
            num_layers,
            vis_encode_type,
        )[0]

        if self.policy.use_recurrent:
            hidden_value, memory_value_out = ModelUtils.create_recurrent_encoder(
                hidden_stream,
                self.memory_in,
                self.policy.sequence_length_ph,
                name="lstm_value",
            )
            self.memory_out = memory_value_out
        else:
            hidden_value = hidden_stream

        self.value_heads, self.value = ModelUtils.create_value_heads(
            self.stream_names, hidden_value
        )

        self.all_old_log_probs = tf.placeholder(
            shape=[None, sum(self.policy.act_size)],
            dtype=tf.float32,
            name="old_probabilities",
        )

        # Break old log probs into separate branches
        old_log_prob_branches = ModelUtils.break_into_branches(
            self.all_old_log_probs, self.policy.act_size
        )

        _, _, old_normalized_logits = ModelUtils.create_discrete_action_masking_layer(
            old_log_prob_branches, self.policy.action_masks, self.policy.act_size
        )

        action_idx = [0] + list(np.cumsum(self.policy.act_size))

        self.old_log_probs = tf.reduce_sum(
            (
                tf.stack(
                    [
                        -tf.nn.softmax_cross_entropy_with_logits_v2(
                            labels=self.policy.selected_actions[
                                :, action_idx[i] : action_idx[i + 1]
                            ],
                            logits=old_normalized_logits[
                                :, action_idx[i] : action_idx[i + 1]
                            ],
                        )
                        for i in range(len(self.policy.act_size))
                    ],
                    axis=1,
                )
            ),
            axis=1,
            keepdims=True,
        )

    def _create_losses(
        self, probs, old_probs, value_heads, entropy, beta, epsilon, lr, max_step
    ):
        """
        Creates training-specific Tensorflow ops for PPO models.
        :param probs: Current policy probabilities
        :param old_probs: Past policy probabilities
        :param value_heads: Value estimate tensors from each value stream
        :param beta: Entropy regularization strength
        :param entropy: Current policy entropy
        :param epsilon: Value for policy-divergence threshold
        :param lr: Learning rate
        :param max_step: Total number of training steps.
        """
        self.returns_holders = {}
        self.old_values = {}
        for name in value_heads.keys():
            returns_holder = tf.placeholder(
                shape=[None], dtype=tf.float32, name="{}_returns".format(name)
            )
            old_value = tf.placeholder(
                shape=[None], dtype=tf.float32, name="{}_value_estimate".format(name)
            )
            self.returns_holders[name] = returns_holder
            self.old_values[name] = old_value
        self.advantage = tf.placeholder(
            shape=[None], dtype=tf.float32, name="advantages"
        )
        advantage = tf.expand_dims(self.advantage, -1)

        self.decay_epsilon = ModelUtils.create_schedule(
            self._schedule, epsilon, self.policy.global_step, max_step, min_value=0.1
        )
        self.decay_beta = ModelUtils.create_schedule(
            self._schedule, beta, self.policy.global_step, max_step, min_value=1e-5
        )

        value_losses = []
        for name, head in value_heads.items():
            clipped_value_estimate = self.old_values[name] + tf.clip_by_value(
                tf.reduce_sum(head, axis=1) - self.old_values[name],
                -self.decay_epsilon,
                self.decay_epsilon,
            )
            v_opt_a = tf.squared_difference(
                self.returns_holders[name], tf.reduce_sum(head, axis=1)
            )
            v_opt_b = tf.squared_difference(
                self.returns_holders[name], clipped_value_estimate
            )
            value_loss = tf.reduce_mean(
                tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.policy.mask, 2)[
                    1
                ]
            )
            value_losses.append(value_loss)
        self.value_loss = tf.reduce_mean(value_losses)

        r_theta = tf.exp(probs - old_probs)
        p_opt_a = r_theta * advantage
        p_opt_b = (
            tf.clip_by_value(
                r_theta, 1.0 - self.decay_epsilon, 1.0 + self.decay_epsilon
            )
            * advantage
        )
        self.policy_loss = -tf.reduce_mean(
            tf.dynamic_partition(tf.minimum(p_opt_a, p_opt_b), self.policy.mask, 2)[1]
        )
        # For cleaner stats reporting
        self.abs_policy_loss = tf.abs(self.policy_loss)

        self.loss = (
            self.policy_loss
            + 0.5 * self.value_loss
            - self.decay_beta
            * tf.reduce_mean(tf.dynamic_partition(entropy, self.policy.mask, 2)[1])
        )

    def _create_ppo_optimizer_ops(self):
        self.tf_optimizer = self.create_optimizer_op(self.learning_rate)
        self.grads = self.tf_optimizer.compute_gradients(self.loss)
        self.update_batch = self.tf_optimizer.minimize(self.loss)

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Performs update on model.
        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process.
        :return: Results of update.
        """
        feed_dict = self._construct_feed_dict(batch, num_sequences)
        stats_needed = self.stats_name_to_update_name
        update_stats = {}
        # Collect feed dicts for all reward signals.
        for _, reward_signal in self.reward_signals.items():
            feed_dict.update(
                reward_signal.prepare_update(self.policy, batch, num_sequences)
            )
            stats_needed.update(reward_signal.stats_name_to_update_name)

        update_vals = self._execute_model(feed_dict, self.update_dict)
        for stat_name, update_name in stats_needed.items():
            update_stats[stat_name] = update_vals[update_name]
        return update_stats

    def _construct_feed_dict(
        self, mini_batch: AgentBuffer, num_sequences: int
    ) -> Dict[tf.Tensor, Any]:
        # Do an optional burn-in for memories
        num_burn_in = int(self.burn_in_ratio * self.policy.sequence_length)
        burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32)
        burn_in_mask[range(0, num_burn_in)] = 0
        burn_in_mask = np.tile(burn_in_mask, num_sequences)
        feed_dict = {
            self.policy.batch_size_ph: num_sequences,
            self.policy.sequence_length_ph: self.policy.sequence_length,
            self.policy.mask_input: mini_batch["masks"] * burn_in_mask,
            self.advantage: mini_batch["advantages"],
            self.all_old_log_probs: mini_batch["action_probs"],
        }
        for name in self.reward_signals:
            feed_dict[self.returns_holders[name]] = mini_batch[
                "{}_returns".format(name)
            ]
            feed_dict[self.old_values[name]] = mini_batch[
                "{}_value_estimates".format(name)
            ]

        if self.policy.output_pre is not None and "actions_pre" in mini_batch:
            feed_dict[self.policy.output_pre] = mini_batch["actions_pre"]
        else:
            feed_dict[self.policy.output] = mini_batch["actions"]
            if self.policy.use_recurrent:
                feed_dict[self.policy.prev_action] = mini_batch["prev_action"]
            feed_dict[self.policy.action_masks] = mini_batch["action_mask"]
        if "vector_obs" in mini_batch:
            feed_dict[self.policy.vector_in] = mini_batch["vector_obs"]
        if self.policy.vis_obs_size > 0:
            for i, _ in enumerate(self.policy.visual_in):
                feed_dict[self.policy.visual_in[i]] = mini_batch["visual_obs%d" % i]
        if self.policy.use_recurrent:
            feed_dict[self.policy.memory_in] = [
                mini_batch["memory"][i]
                for i in range(
                    0, len(mini_batch["memory"]), self.policy.sequence_length
                )
            ]
            feed_dict[self.memory_in] = self._make_zero_mem(
                self.m_size, mini_batch.num_experiences
            )
        return feed_dict
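For reference, the clipped surrogate that _create_losses assembles as TensorFlow ops can be sketched with plain NumPy on toy numbers. The snippet below is illustrative only and not part of the commit; the array values are invented, and epsilon stands in for the decayed clip range (decay_epsilon). Only the r_theta / p_opt_a / p_opt_b clipping mirrors the ops above.

import numpy as np

# Toy values, for illustration only -- not part of the commit.
log_probs = np.array([-1.0, -0.5, -2.0])      # current policy log-probs (summed per action)
old_log_probs = np.array([-1.1, -0.7, -1.5])  # log-probs recorded when the data was collected
advantage = np.array([0.5, -0.2, 1.0])        # advantages fed in through self.advantage
epsilon = 0.2                                 # stand-in for the decayed clip range

r_theta = np.exp(log_probs - old_log_probs)   # probability ratio pi_new / pi_old
p_opt_a = r_theta * advantage                 # unclipped objective
p_opt_b = np.clip(r_theta, 1.0 - epsilon, 1.0 + epsilon) * advantage  # clipped objective
policy_loss = -np.mean(np.minimum(p_opt_a, p_opt_b))  # negated so gradient descent maximizes the objective
print(policy_loss)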

ml-agents/mlagents/trainers/ppo_transfer/trainer.py (+276)


# # Unity ML-Agents Toolkit
# ## ML-Agent Learning (PPO)
# Contains an implementation of PPO as described in: https://arxiv.org/abs/1707.06347

from collections import defaultdict
from typing import cast

import numpy as np

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.policy.nn_policy import NNPolicy
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, PPOSettings


logger = get_logger(__name__)


class PPOTrainer(RLTrainer):
    """The PPOTrainer is an implementation of the PPO algorithm."""

    def __init__(
        self,
        brain_name: str,
        reward_buff_cap: int,
        trainer_settings: TrainerSettings,
        training: bool,
        load: bool,
        seed: int,
        artifact_path: str,
    ):
        """
        Responsible for collecting experiences and training PPO model.
        :param brain_name: The name of the brain associated with trainer config
        :param reward_buff_cap: Max reward history to track in the reward buffer
        :param trainer_settings: The parameters for the trainer.
        :param training: Whether the trainer is set for training.
        :param load: Whether the model should be loaded.
        :param seed: The seed the model will be initialized with
        :param artifact_path: The directory within which to store artifacts from this trainer.
        """
        super(PPOTrainer, self).__init__(
            brain_name, trainer_settings, training, artifact_path, reward_buff_cap
        )
        self.hyperparameters: PPOSettings = cast(
            PPOSettings, self.trainer_settings.hyperparameters
        )
        self.load = load
        self.seed = seed
        self.policy: NNPolicy = None  # type: ignore

    def _process_trajectory(self, trajectory: Trajectory) -> None:
        """
        Takes a trajectory and processes it, putting it into the update buffer.
        Processing involves calculating value and advantage targets for model updating step.
        :param trajectory: The Trajectory tuple containing the steps to be processed.
        """
        super()._process_trajectory(trajectory)
        agent_id = trajectory.agent_id  # All the agents should have the same ID

        agent_buffer_trajectory = trajectory.to_agentbuffer()
        # Update the normalization
        if self.is_training:
            self.policy.update_normalization(agent_buffer_trajectory["vector_obs"])

        # Get all value estimates
        value_estimates, value_next = self.optimizer.get_trajectory_value_estimates(
            agent_buffer_trajectory,
            trajectory.next_obs,
            trajectory.done_reached and not trajectory.interrupted,
        )
        for name, v in value_estimates.items():
            agent_buffer_trajectory["{}_value_estimates".format(name)].extend(v)
            self._stats_reporter.add_stat(
                self.optimizer.reward_signals[name].value_name, np.mean(v)
            )

        # Evaluate all reward functions
        self.collected_rewards["environment"][agent_id] += np.sum(
            agent_buffer_trajectory["environment_rewards"]
        )
        for name, reward_signal in self.optimizer.reward_signals.items():
            evaluate_result = reward_signal.evaluate_batch(
                agent_buffer_trajectory
            ).scaled_reward
            agent_buffer_trajectory["{}_rewards".format(name)].extend(evaluate_result)
            # Report the reward signals
            self.collected_rewards[name][agent_id] += np.sum(evaluate_result)

        # Compute GAE and returns
        tmp_advantages = []
        tmp_returns = []
        for name in self.optimizer.reward_signals:
            bootstrap_value = value_next[name]

            local_rewards = agent_buffer_trajectory[
                "{}_rewards".format(name)
            ].get_batch()
            local_value_estimates = agent_buffer_trajectory[
                "{}_value_estimates".format(name)
            ].get_batch()

            local_advantage = get_gae(
                rewards=local_rewards,
                value_estimates=local_value_estimates,
                value_next=bootstrap_value,
                gamma=self.optimizer.reward_signals[name].gamma,
                lambd=self.hyperparameters.lambd,
            )
            local_return = local_advantage + local_value_estimates
            # This is later used as a target for the different value estimates
            agent_buffer_trajectory["{}_returns".format(name)].set(local_return)
            agent_buffer_trajectory["{}_advantage".format(name)].set(local_advantage)
            tmp_advantages.append(local_advantage)
            tmp_returns.append(local_return)

        # Get global advantages
        global_advantages = list(
            np.mean(np.array(tmp_advantages, dtype=np.float32), axis=0)
        )
        global_returns = list(np.mean(np.array(tmp_returns, dtype=np.float32), axis=0))
        agent_buffer_trajectory["advantages"].set(global_advantages)
        agent_buffer_trajectory["discounted_returns"].set(global_returns)

        # Append to update buffer
        agent_buffer_trajectory.resequence_and_append(
            self.update_buffer, training_length=self.policy.sequence_length
        )

        # If this was a terminal trajectory, append stats and reset reward collection
        if trajectory.done_reached:
            self._update_end_episode_stats(agent_id, self.optimizer)

    def _is_ready_update(self):
        """
        Returns whether or not the trainer has enough elements to run update model
        :return: A boolean corresponding to whether or not update_model() can be run
        """
        size_of_buffer = self.update_buffer.num_experiences
        return size_of_buffer > self.hyperparameters.buffer_size

    def _update_policy(self):
        """
        Uses demonstration_buffer to update the policy.
        The reward signal generators must be updated in this method at their own pace.
        """
        buffer_length = self.update_buffer.num_experiences
        self.cumulative_returns_since_policy_update.clear()

        # Make sure batch_size is a multiple of sequence length. During training, we
        # will need to reshape the data into a batch_size x sequence_length tensor.
        batch_size = (
            self.hyperparameters.batch_size
            - self.hyperparameters.batch_size % self.policy.sequence_length
        )
        # Make sure there is at least one sequence
        batch_size = max(batch_size, self.policy.sequence_length)

        n_sequences = max(
            int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
        )

        advantages = self.update_buffer["advantages"].get_batch()
        self.update_buffer["advantages"].set(
            (advantages - advantages.mean()) / (advantages.std() + 1e-10)
        )
        num_epoch = self.hyperparameters.num_epoch
        batch_update_stats = defaultdict(list)
        for _ in range(num_epoch):
            self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
            buffer = self.update_buffer
            max_num_batch = buffer_length // batch_size
            for i in range(0, max_num_batch * batch_size, batch_size):
                update_stats = self.optimizer.update(
                    buffer.make_mini_batch(i, i + batch_size), n_sequences
                )
                for stat_name, value in update_stats.items():
                    batch_update_stats[stat_name].append(value)

        for stat, stat_list in batch_update_stats.items():
            self._stats_reporter.add_stat(stat, np.mean(stat_list))

        if self.optimizer.bc_module:
            update_stats = self.optimizer.bc_module.update()
            for stat, val in update_stats.items():
                self._stats_reporter.add_stat(stat, val)
        self._clear_update_buffer()
        return True

    def create_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
    ) -> TFPolicy:
        """
        Creates a PPO policy to add to the trainer's list of policies.
        :param brain_parameters: specifications for policy construction
        :return: policy
        """
        policy = NNPolicy(
            self.seed,
            brain_parameters,
            self.trainer_settings,
            self.is_training,
            self.artifact_path,
            self.load,
            condition_sigma_on_obs=False,  # Faster training for PPO
            create_tf_graph=False,  # We will create the TF graph in the Optimizer
        )

        return policy

    def add_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
    ) -> None:
        """
        Adds policy to trainer.
        :param parsed_behavior_id: Behavior identifiers that the policy should belong to.
        :param policy: Policy to associate with name_behavior_id.
        """
        if self.policy:
            logger.warning(
                "Your environment contains multiple teams, but {} doesn't support adversarial games. Enable self-play to \
train adversarial games.".format(
                    self.__class__.__name__
                )
            )
        if not isinstance(policy, NNPolicy):
            raise RuntimeError("Non-NNPolicy passed to PPOTrainer.add_policy()")
        self.policy = policy
        self.optimizer = PPOOptimizer(self.policy, self.trainer_settings)
        for _reward_signal in self.optimizer.reward_signals.keys():
            self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
        # Needed to resume loads properly
        self.step = policy.get_current_step()

    def get_policy(self, name_behavior_id: str) -> TFPolicy:
        """
        Gets policy from trainer associated with name_behavior_id
        :param name_behavior_id: full identifier of policy
        """
        return self.policy


def discount_rewards(r, gamma=0.99, value_next=0.0):
    """
    Computes discounted sum of future rewards for use in updating value estimate.
    :param r: List of rewards.
    :param gamma: Discount factor.
    :param value_next: T+1 value estimate for returns calculation.
    :return: discounted sum of future rewards as list.
    """
    discounted_r = np.zeros_like(r)
    running_add = value_next
    for t in reversed(range(0, r.size)):
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r


def get_gae(rewards, value_estimates, value_next=0.0, gamma=0.99, lambd=0.95):
    """
    Computes generalized advantage estimate for use in updating policy.
    :param rewards: list of rewards for time-steps t to T.
    :param value_next: Value estimate for time-step T+1.
    :param value_estimates: list of value estimates for time-steps t to T.
    :param gamma: Discount factor.
    :param lambd: GAE weighing factor.
    :return: list of advantage estimates for time-steps t to T.
    """
    value_estimates = np.append(value_estimates, value_next)
    delta_t = rewards + gamma * value_estimates[1:] - value_estimates[:-1]
    advantage = discount_rewards(r=delta_t, gamma=gamma * lambd)
    return advantage
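As a quick illustration of how the two helpers above are combined in _process_trajectory, here is a toy call. It is not part of the commit: the numbers are invented, and it assumes get_gae from this file is in scope (e.g. run in the same module or after importing it).

import numpy as np

# Invented three-step trajectory: per-step rewards and value estimates, plus a bootstrap value.
rewards = np.array([1.0, 0.0, 1.0], dtype=np.float32)
values = np.array([0.5, 0.4, 0.6], dtype=np.float32)

advantages = get_gae(rewards, values, value_next=0.3, gamma=0.99, lambd=0.95)
returns = advantages + values  # same combination used for the "{}_returns" targets above
print(advantages, returns)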