
Move methods into common optimizer

/develop/nopreviousactions
Ervin Teng, 5 years ago
Commit 2373cae8
4 changed files with 248 additions and 122 deletions
  1. ml-agents/mlagents/trainers/models.py (20)
  2. ml-agents/mlagents/trainers/optimizer.py (114)
  3. ml-agents/mlagents/trainers/ppo/models.py (8)
  4. ml-agents/mlagents/trainers/ppo/optimizer.py (228)

ml-agents/mlagents/trainers/models.py (20 changes)


import logging
from enum import Enum
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
from mlagents.tf_utils import tf

)
self.vector_in = LearningModel.create_vector_input(self.vec_obs_size)
if self.normalize:
self.update_normalization, self.normalization_steps, self.running_mean, self.running_variance = LearningModel.create_normalizer(
self.vector_in
)
normalization_tensors = LearningModel.create_normalizer(self.vector_in)
self.update_normalization = normalization_tensors[0]
self.normalization_steps = normalization_tensors[1]
self.running_mean = normalization_tensors[2]
self.running_variance = normalization_tensors[3]
self.processed_vector_in = self.normalize_vector_obs(self.vector_in)
else:
self.processed_vector_in = self.vector_in

return visual_in
@staticmethod
def create_vector_input(vec_obs_size: int, name="vector_observation"):
def create_vector_input(
vec_obs_size: int, name: str = "vector_observation"
) -> tf.Tensor:
"""
Creates ops for vector observation input.
:param name: Name of the placeholder op.
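The change to create_vector_input only adds type annotations; the body is not touched by this commit. For reference, a method like this typically boils down to a single TF 1.x placeholder, shown here as an assumption rather than the file's actual body:

    import tensorflow as tf  # TF 1.x API, as used throughout this diff

    def create_vector_input(vec_obs_size: int, name: str = "vector_observation") -> tf.Tensor:
        # one row per agent step, vec_obs_size features per row
        return tf.placeholder(shape=[None, vec_obs_size], dtype=tf.float32, name=name)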

return normalized_state
@staticmethod
def create_normalizer(vector_obs: tf.Tensor):
def create_normalizer(
vector_obs: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
vec_obs_size = vector_obs.shape[1]
steps = tf.get_variable(
"normalization_steps",

steps: tf.Tensor,
running_mean: tf.Tensor,
running_variance: tf.Tensor,
):
) -> tf.Operation:
# Based on Welford's algorithm for running mean and standard deviation, for batch updates. Discussion here:
# https://stackoverflow.com/questions/56402955/whats-the-formula-for-welfords-algorithm-for-variance-std-with-batch-updates
steps_increment = tf.shape(vector_input)[0]
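For context, the batch form of Welford's update referenced in the comment can be written as follows. This NumPy sketch is illustrative only; the names (count, mean, m2) do not map onto the TF variables in the diff:

    import numpy as np

    def welford_batch_update(count, mean, m2, batch):
        # Fold a whole batch into running (count, mean, M2); variance = m2 / count.
        k = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
        delta = batch_mean - mean
        total = count + k
        new_mean = mean + delta * k / total
        new_m2 = m2 + batch_m2 + delta ** 2 * count * k / total
        return total, new_mean, new_m2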

ml-agents/mlagents/trainers/optimizer.py (114 changes)


import abc
from typing import Dict, Any
from tf_utils import tf
from typing import Dict, Any, List
import numpy as np
from mlagents.tf_utils.tf import tf
from mlagents.trainers.policy import Policy
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.models import LearningModel
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
create_reward_signal,
)
class Optimizer(abc.ABC):

"""
def __init__(self, policy: Policy, optimizer_parameters: Dict[str, Any]):
@abc.abstractmethod
def __init__(self, policy: LearningModel, optimizer_parameters: Dict[str, Any]):
pass
def update_batch(self, batch: AgentBuffer):
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
pass
class TFOptimizer(Optimizer):
def __init__(self, sess: tf.Session, policy: TFPolicy, reward_signal_configs):
class TFOptimizer(Optimizer, abc.ABC): # pylint: disable=W0223
def __init__(
self, sess: tf.Session, policy: PPOModel, reward_signal_configs: Dict[str, Any]
):
super().__init__(policy, reward_signal_configs)
def get_batched_value_estimates(self, batch: AgentBuffer) -> Dict[str, np.ndarray]:
feed_dict: Dict[tf.Tensor, Any] = {
self.policy.batch_size: batch.num_experiences,
self.policy.sequence_length: 1, # We want to feed data in batch-wise, not time-wise.
}
if self.policy.vec_obs_size > 0:
feed_dict[self.policy.vector_in] = batch["vector_obs"]
if self.policy.vis_obs_size > 0:
for i in range(len(self.policy.visual_in)):
_obs = batch["visual_obs%d" % i]
feed_dict[self.policy.visual_in[i]] = _obs
if self.policy.use_recurrent:
feed_dict[self.policy.memory_in] = batch["memory"]
if self.policy.prev_action is not None:
feed_dict[self.policy.prev_action] = batch["prev_action"]
value_estimates = self.sess.run(self.value_heads, feed_dict)
value_estimates = {k: np.squeeze(v, axis=1) for k, v in value_estimates.items()}
return value_estimates
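The post-processing step above turns each value head's (N, 1) output from sess.run into a flat (N,) array keyed by reward-signal name. A small NumPy illustration with hypothetical data:

    import numpy as np

    raw = {"extrinsic": np.zeros((32, 1)), "curiosity": np.ones((32, 1))}  # stand-in for sess.run output
    value_estimates = {k: np.squeeze(v, axis=1) for k, v in raw.items()}
    assert value_estimates["extrinsic"].shape == (32,)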
def get_value_estimates(
self, next_obs: List[np.ndarray], agent_id: str, done: bool
) -> Dict[str, float]:
"""
Generates value estimates for bootstrapping.
:param experience: AgentExperience to be used for bootstrapping.
:param done: Whether or not this is the last element of the episode, in which case the value estimate will be 0.
:return: The value estimate dictionary with key being the name of the reward signal and the value the
corresponding value estimate.
"""
feed_dict: Dict[tf.Tensor, Any] = {
self.policy.batch_size: 1,
self.policy.sequence_length: 1,
}
vec_vis_obs = SplitObservations.from_observations(next_obs)
for i in range(len(vec_vis_obs.visual_observations)):
feed_dict[self.policy.visual_in[i]] = [vec_vis_obs.visual_observations[i]]
if self.policy.vec_obs_size > 0:
feed_dict[self.policy.vector_in] = [vec_vis_obs.vector_observations]
# if self.policy.use_recurrent:
# feed_dict[self.policy.memory_in] = self.policy.retrieve_memories([agent_id])
# if self.policy.prev_action is not None:
# feed_dict[self.policy.prev_action] = self.policy.retrieve_previous_action(
# [agent_id]
# )
value_estimates = self.sess.run(self.value_heads, feed_dict)
value_estimates = {k: float(v) for k, v in value_estimates.items()}
# If we're done, reassign all of the value estimates that need terminal states.
if done:
for k in value_estimates:
if self.reward_signals[k].use_terminal_states:
value_estimates[k] = 0.0
return value_estimates
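These bootstrap estimates are typically folded into the discounted-return computation downstream; the sketch below shows the usual pattern and is not code from this commit:

    def discounted_returns(rewards, gamma, bootstrap_value):
        # bootstrap_value is the estimate returned above; it is 0.0 when the episode
        # terminated and the reward signal uses terminal states
        returns = []
        running = bootstrap_value
        for r in reversed(rewards):
            running = r + gamma * running
            returns.insert(0, running)
        return returns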
def create_reward_signals(self, reward_signal_configs):
"""
Create reward signals
:param reward_signal_configs: Reward signal config.
"""
self.reward_signals = {}
# Create reward signals
for reward_signal, config in reward_signal_configs.items():
self.reward_signals[reward_signal] = create_reward_signal(
self, self.policy, reward_signal, config
)
self.update_dict.update(self.reward_signals[reward_signal].update_dict)
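The reward_signal_configs argument mirrors the reward_signals section of the trainer YAML; the exact keys depend on the reward-signal classes, but a representative (illustrative) value looks like this:

    reward_signal_configs = {
        "extrinsic": {"strength": 1.0, "gamma": 0.99},
        "curiosity": {"strength": 0.02, "gamma": 0.99, "encoding_size": 256},
    }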
def create_value_heads(self, stream_names, hidden_input):
"""
Creates one value estimator head for each reward signal in stream_names.
Also creates the node corresponding to the mean of all the value heads in self.value.
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal.
:param stream_names: The list of reward signal names
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top
of the hidden input.
"""
for name in stream_names:
value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name))
self.value_heads[name] = value
self.value = tf.reduce_mean(list(self.value_heads.values()), 0)
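The same pattern, pulled out as a standalone helper for reference (a sketch assuming the TF 1.x API used throughout this diff):

    import tensorflow as tf  # TF 1.x API

    def build_value_heads(hidden, stream_names):
        value_heads = {}
        for name in stream_names:
            # one scalar estimator per reward signal
            value_heads[name] = tf.layers.dense(hidden, 1, name="{}_value".format(name))
        # mean over all heads, used where a single aggregate estimate is needed
        value = tf.reduce_mean(list(value_heads.values()), 0)
        return value_heads, value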

ml-agents/mlagents/trainers/ppo/models.py (8 changes)


:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
2, h_size, num_layers, vis_encode_type
self.visual_in,
self.processed_vector_in,
2,
h_size,
num_layers,
vis_encode_type,
)
if self.use_recurrent:

axis=1,
keepdims=True,
)

ml-agents/mlagents/trainers/ppo/optimizer.py (228 changes)


import logging
from typing import Optional, Dict, List, Any
from typing import Optional, Any, Dict
from mlagents_envs.timers import timed
from mlagents.trainers.optimizer import TFOptimizer
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
create_reward_signal,
)
from mlagents.trainers.ppo.models import PPOModel
logger = logging.getLogger("mlagents.trainers")

:param stream_names: List of names of value streams. Usually, a list of the Reward Signals being used.
:return: a sub-class of PPOAgent tailored to the environment.
"""
super().__init__(sess, policy, reward_signal_configs)
super().__init__(self, sess, self.policy)
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
}
if num_layers < 1:
num_layers = 1
if brain.vector_action_space_type == "continuous":

self.learning_rate = self.create_learning_rate(
lr_schedule, lr, self.global_step, max_step
self.learning_rate = LearningModel.create_learning_rate(
lr_schedule, lr, self.policy.global_step, max_step
)
self.create_losses(
self.policy.log_probs,

if self.policy.use_recurrent:
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
shape=[None, self.policy.m_size], dtype=tf.float32, name="recurrent_in"
_half_point = int(self.m_size / 2)
_half_point = int(self.policy.m_size / 2)
hidden_value, memory_value_out = self.create_recurrent_encoder(
hidden_value, memory_value_out = LearningModel.create_recurrent_encoder(
hidden_stream,
self.memory_in[:, _half_point:],
self.policy.sequence_length,

:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
2, h_size, num_layers, vis_encode_type
hidden_streams = LearningModel.create_observation_streams(
self.policy.visual_in,
self.policy.processed_vector_in,
2,
h_size,
num_layers,
vis_encode_type,
if self.use_recurrent:
if self.policy.use_recurrent:
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
shape=[None, len(self.policy.act_size)],
dtype=tf.int32,
name="prev_action",
tf.one_hot(self.prev_action[:, i], self.act_size[i])
for i in range(len(self.act_size))
tf.one_hot(self.prev_action[:, i], self.policy.act_size[i])
for i in range(len(self.policy.act_size))
],
axis=1,
)

shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
shape=[None, self.policy.m_size], dtype=tf.float32, name="recurrent_in"
_half_point = int(self.m_size / 2)
hidden_policy, memory_policy_out = self.create_recurrent_encoder(
_half_point = int(self.policy.m_size / 2)
hidden_policy, memory_policy_out = LearningModel.create_recurrent_encoder(
self.sequence_length,
self.policy.sequence_length,
hidden_value, memory_value_out = self.create_recurrent_encoder(
hidden_value, memory_value_out = LearningModel.create_recurrent_encoder(
self.sequence_length,
self.policy.sequence_length,
name="lstm_value",
)
self.memory_out = tf.concat(

hidden_value = hidden_streams[1]
policy_branches = []
for size in self.act_size:
for size in self.policy.act_size:
policy_branches.append(
tf.layers.dense(
hidden_policy,

self.all_log_probs = tf.concat(policy_branches, axis=1, name="action_probs")
self.action_masks = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
shape=[None, sum(self.policy.act_size)],
dtype=tf.float32,
name="action_masks",
output, _, normalized_logits = self.create_discrete_action_masking_layer(
self.all_log_probs, self.action_masks, self.act_size
output, _, normalized_logits = LearningModel.create_discrete_action_masking_layer(
self.all_log_probs, self.action_masks, self.policy.act_size
)
self.output = tf.identity(output)

)
self.action_oh = tf.concat(
[
tf.one_hot(self.action_holder[:, i], self.act_size[i])
for i in range(len(self.act_size))
tf.one_hot(self.action_holder[:, i], self.policy.act_size[i])
for i in range(len(self.policy.act_size))
],
axis=1,
)

shape=[None, sum(self.act_size)], dtype=tf.float32, name="old_probabilities"
shape=[None, sum(self.policy.act_size)],
dtype=tf.float32,
name="old_probabilities",
_, _, old_normalized_logits = self.create_discrete_action_masking_layer(
self.all_old_log_probs, self.action_masks, self.act_size
_, _, old_normalized_logits = LearningModel.create_discrete_action_masking_layer(
self.all_old_log_probs, self.action_masks, self.policy.act_size
action_idx = [0] + list(np.cumsum(self.act_size))
action_idx = [0] + list(np.cumsum(self.policy.act_size))
self.entropy = tf.reduce_sum(
(

:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
for i in range(len(self.policy.act_size))
],
axis=1,
)

:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
for i in range(len(self.policy.act_size))
],
axis=1,
)

:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
for i in range(len(self.policy.act_size))
],
axis=1,
)

advantage = tf.expand_dims(self.advantage, -1)
decay_epsilon = tf.train.polynomial_decay(
epsilon, self.global_step, max_step, 0.1, power=1.0
epsilon, self.policy.global_step, max_step, 0.1, power=1.0
beta, self.global_step, max_step, 1e-5, power=1.0
beta, self.policy.global_step, max_step, 1e-5, power=1.0
)
value_losses = []

self.returns_holders[name], clipped_value_estimate
)
value_loss = tf.reduce_mean(
tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.mask, 2)[1]
tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.policy.mask, 2)[
1
]
)
value_losses.append(value_loss)
self.value_loss = tf.reduce_mean(value_losses)

* advantage
)
self.policy_loss = -tf.reduce_mean(
tf.dynamic_partition(tf.minimum(p_opt_a, p_opt_b), self.mask, 2)[1]
tf.dynamic_partition(tf.minimum(p_opt_a, p_opt_b), self.policy.mask, 2)[1]
)
# For cleaner stats reporting
self.abs_policy_loss = tf.abs(self.policy_loss)

+ 0.5 * self.value_loss
- decay_beta
* tf.reduce_mean(tf.dynamic_partition(entropy, self.mask, 2)[1])
* tf.reduce_mean(tf.dynamic_partition(entropy, self.policy.mask, 2)[1])
)
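The terms above implement the standard PPO objective: a clipped value loss (v_opt_a/v_opt_b), a clipped policy surrogate (p_opt_a/p_opt_b), and an entropy bonus scaled by the decaying beta. A NumPy sketch of the two clipped terms, omitting the masking and decay schedules, for reference only:

    import numpy as np

    def clipped_policy_loss(new_log_probs, old_log_probs, advantages, epsilon):
        ratio = np.exp(new_log_probs - old_log_probs)
        p_opt_a = ratio * advantages
        p_opt_b = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
        return -np.mean(np.minimum(p_opt_a, p_opt_b))

    def clipped_value_loss(returns, old_values, new_values, epsilon):
        clipped = old_values + np.clip(new_values - old_values, -epsilon, epsilon)
        v_opt_a = (returns - new_values) ** 2
        v_opt_b = (returns - clipped) ** 2
        return np.mean(np.maximum(v_opt_a, v_opt_b))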
def create_ppo_optimizer(self):

def get_batched_value_estimates(self, batch: AgentBuffer) -> Dict[str, np.ndarray]:
feed_dict: Dict[tf.Tensor, Any] = {
self.policy.batch_size: batch.num_experiences,
self.policy.sequence_length: 1, # We want to feed data in batch-wise, not time-wise.
}
if self.policy.vec_obs_size > 0:
feed_dict[self.policy.vector_in] = batch["vector_obs"]
if self.policy.vis_obs_size > 0:
for i in range(len(self.policy.visual_in)):
_obs = batch["visual_obs%d" % i]
feed_dict[self.policy.visual_in[i]] = _obs
if self.policy.use_recurrent:
feed_dict[self.policy.memory_in] = batch["memory"]
if self.policy.prev_action is not None:
feed_dict[self.policy.prev_action] = batch["prev_action"]
value_estimates = self.sess.run(self.value_heads, feed_dict)
value_estimates = {k: np.squeeze(v, axis=1) for k, v in value_estimates.items()}
return value_estimates
def get_value_estimates(
self, next_obs: List[np.ndarray], agent_id: str, done: bool
) -> Dict[str, float]:
@timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
Generates value estimates for bootstrapping.
:param experience: AgentExperience to be used for bootstrapping.
:param done: Whether or not this is the last element of the episode, in which case the value estimate will be 0.
:return: The value estimate dictionary with key being the name of the reward signal and the value the
corresponding value estimate.
Performs update on model.
:param mini_batch: Batch of experiences.
:param num_sequences: Number of sequences to process.
:return: Results of update.
feed_dict: Dict[tf.Tensor, Any] = {
self.policy.batch_size: 1,
self.policy.sequence_length: 1,
}
vec_vis_obs = SplitObservations.from_observations(next_obs)
for i in range(len(vec_vis_obs.visual_observations)):
feed_dict[self.policy.visual_in[i]] = [vec_vis_obs.visual_observations[i]]
if self.policy.vec_obs_size > 0:
feed_dict[self.policy.vector_in] = [vec_vis_obs.vector_observations]
if self.policy.use_recurrent:
feed_dict[self.policy.memory_in] = self.retrieve_memories([agent_id])
if self.policy.prev_action is not None:
feed_dict[self.policy.prev_action] = self.retrieve_previous_action(
[agent_id]
feed_dict = self.construct_feed_dict(self.policy, batch, num_sequences)
stats_needed = self.stats_name_to_update_name
update_stats = {}
# Collect feed dicts for all reward signals.
for _, reward_signal in self.reward_signals.items():
feed_dict.update(
reward_signal.prepare_update(self.policy, batch, num_sequences)
value_estimates = self.sess.run(self.value_heads, feed_dict)
stats_needed.update(reward_signal.stats_name_to_update_name)
value_estimates = {k: float(v) for k, v in value_estimates.items()}
update_vals = self._execute_model(feed_dict, self.update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
return update_stats
# If we're done, reassign all of the value estimates that need terminal states.
if done:
for k in value_estimates:
if self.reward_signals[k].use_terminal_states:
value_estimates[k] = 0.0
def construct_feed_dict(
self, model: PPOModel, mini_batch: AgentBuffer, num_sequences: int
) -> Dict[tf.Tensor, Any]:
feed_dict = {
model.batch_size: num_sequences,
model.sequence_length: model.sequence_length,
model.mask_input: mini_batch["masks"],
self.advantage: mini_batch["advantages"],
self.all_old_log_probs: mini_batch["action_probs"],
}
for name in self.reward_signals:
feed_dict[self.returns_holders[name]] = mini_batch[
"{}_returns".format(name)
]
feed_dict[self.old_values[name]] = mini_batch[
"{}_value_estimates".format(name)
]
return value_estimates
if "actions_pre" in mini_batch:
feed_dict[model.output_pre] = mini_batch["actions_pre"]
else:
feed_dict[model.action_holder] = mini_batch["actions"]
if "prev_action" in mini_batch:
feed_dict[model.prev_action] = mini_batch["prev_action"]
feed_dict[model.action_masks] = mini_batch["action_mask"]
if "vector_obs" in mini_batch:
feed_dict[model.vector_in] = mini_batch["vector_obs"]
if model.vis_obs_size > 0:
for i, _ in enumerate(model.visual_in):
feed_dict[model.visual_in[i]] = mini_batch["visual_obs%d" % i]
if "memory" in mini_batch:
mem_in = [
mini_batch["memory"][i]
for i in range(
0, len(mini_batch["memory"]), self.policy.sequence_length
)
]
feed_dict[model.memory_in] = mem_in
return feed_dict
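The memory handling above keeps one recurrent state per sequence rather than per step; a tiny illustration with made-up values:

    sequence_length = 4
    memories = [[float(i)] for i in range(12)]   # stands in for mini_batch["memory"]
    mem_in = [memories[i] for i in range(0, len(memories), sequence_length)]
    # mem_in == [[0.0], [4.0], [8.0]]: one initial recurrent state per sequence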
def create_reward_signals(self, reward_signal_configs):
def _execute_model(self, feed_dict, out_dict):
Create reward signals
:param reward_signal_configs: Reward signal config.
Executes model.
:param feed_dict: Input dictionary mapping nodes to input data.
:param out_dict: Output dictionary mapping names to nodes.
:return: Dictionary mapping names to output values.
self.reward_signals = {}
# Create reward signals
for reward_signal, config in reward_signal_configs.items():
self.reward_signals[reward_signal] = create_reward_signal(
self, self.policy, reward_signal, config
)
self.update_dict.update(self.reward_signals[reward_signal].update_dict)
network_out = self.sess.run(list(out_dict.values()), feed_dict=feed_dict)
run_out = dict(zip(list(out_dict.keys()), network_out))
return run_out
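The zip at the end pairs each update-dict key with the corresponding fetched value; with sess.run stubbed out, the round trip looks like this (values hypothetical):

    out_dict = {"value_loss": "value_loss_op", "policy_loss": "policy_loss_op"}  # name -> graph node
    network_out = [0.5, -0.1]  # stand-in for sess.run(list(out_dict.values()), feed_dict)
    run_out = dict(zip(list(out_dict.keys()), network_out))
    assert run_out == {"value_loss": 0.5, "policy_loss": -0.1}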