浏览代码

Move some functionality to optimizer

/develop/nopreviousactions
Ervin Teng 5 年前
当前提交
03c750a7
共有 3 个文件被更改,包括 519 次插入71 次删除
  1. 150
      ml-agents/mlagents/trainers/models.py
  2. 37
      ml-agents/mlagents/trainers/ppo/models.py
  3. 403
      ml-agents/mlagents/trainers/ppo/optimizer.py

150
ml-agents/mlagents/trainers/models.py


):
tf.set_random_seed(seed)
self.brain = brain
self.vector_in = None
self.normalize = normalize
self.act_size = brain.vector_action_space_size
self.vec_obs_size = brain.vector_observation_space_size
self.vis_obs_size = brain.number_visual_observations
self.value_heads: Dict[str, tf.Tensor] = {}
self.normalization_steps: Optional[tf.Variable] = None
self.running_mean: Optional[tf.Variable] = None
self.running_variance: Optional[tf.Variable] = None
self.update_normalization: Optional[tf.Operation] = None
self.value: Optional[tf.Tensor] = None
self.all_log_probs: Optional[tf.Tensor] = None
self.output: Optional[tf.Tensor] = None
self.selected_actions: Optional[tf.Tensor] = None
self.action_holder: Optional[tf.Tensor] = None
self.visual_in = []
self.visual_in = LearningModel.create_visual_input_placeholders(
brain.camera_resolutions
)
self.vector_in = LearningModel.create_vector_input(self.vec_obs_size)
if self.normalize:
self.update_normalization, self.normalization_steps, self.running_mean, self.running_variance = LearningModel.create_normalizer(
self.vector_in
)
self.processed_vector_in = self.normalize_vector_obs(self.vector_in)
else:
self.processed_vector_in = self.vector_in
self.update_normalization = None
self.batch_size = tf.placeholder(shape=None, dtype=tf.int32, name="batch_size")
self.sequence_length = tf.placeholder(
shape=None, dtype=tf.int32, name="sequence_length"

self.m_size = m_size
else:
self.m_size = 0
self.normalize = normalize
self.act_size = brain.vector_action_space_size
self.vec_obs_size = brain.vector_observation_space_size
self.vis_obs_size = brain.number_visual_observations
tf.Variable(
int(brain.vector_action_space_type == "continuous"),
name="is_continuous_control",

trainable=False,
dtype=tf.int32,
)
self.value_heads: Dict[str, tf.Tensor] = {}
self.normalization_steps: Optional[tf.Variable] = None
self.running_mean: Optional[tf.Variable] = None
self.running_variance: Optional[tf.Variable] = None
self.update_normalization: Optional[tf.Operation] = None
self.value: Optional[tf.Tensor] = None
self.all_log_probs: Optional[tf.Tensor] = None
self.output: Optional[tf.Tensor] = None
self.selected_actions: Optional[tf.Tensor] = None
self.action_holder: Optional[tf.Tensor] = None
@staticmethod
def create_global_steps():

)
return visual_in
def create_vector_input(self, name="vector_observation"):
@staticmethod
def create_visual_input_placeholders(
camera_resolutions: List[CameraResolution]
) -> List[tf.Tensor]:
visual_in: List[tf.Tensor] = []
for i, camera_resolution in enumerate(camera_resolutions):
visual_input = LearningModel.create_visual_input(
camera_resolution, name="visual_observation_" + str(i)
)
visual_in.append(visual_input)
return visual_in
@staticmethod
def create_vector_input(vec_obs_size: int, name="vector_observation"):
"""
Creates ops for vector observation input.
:param name: Name of the placeholder op.

self.vector_in = tf.placeholder(
shape=[None, self.vec_obs_size], dtype=tf.float32, name=name
vector_in = tf.placeholder(
shape=[None, vec_obs_size], dtype=tf.float32, name=name
if self.normalize:
self.create_normalizer(self.vector_in)
return self.normalize_vector_obs(self.vector_in)
else:
return self.vector_in
return vector_in
def normalize_vector_obs(self, vector_obs):
normalized_state = tf.clip_by_value(

)
return normalized_state
def create_normalizer(self, vector_obs):
self.normalization_steps = tf.get_variable(
@staticmethod
def create_normalizer(vector_obs: tf.Tensor):
vec_obs_size = vector_obs.shape[1]
steps = tf.get_variable(
"normalization_steps",
[],
trainable=False,

self.running_mean = tf.get_variable(
running_mean = tf.get_variable(
[self.vec_obs_size],
[vec_obs_size],
self.running_variance = tf.get_variable(
running_variance = tf.get_variable(
[self.vec_obs_size],
[vec_obs_size],
self.update_normalization = self.create_normalizer_update(vector_obs)
update_normalization = LearningModel.create_normalizer_update(
vector_obs, steps, running_mean, running_variance
)
return update_normalization, steps, running_mean, running_variance
def create_normalizer_update(self, vector_input):
@staticmethod
def create_normalizer_update(
vector_input: tf.Tensor,
steps: tf.Tensor,
running_mean: tf.Tensor,
running_variance: tf.Tensor,
):
total_new_steps = tf.add(self.normalization_steps, steps_increment)
total_new_steps = tf.add(steps, steps_increment)
input_to_old_mean = tf.subtract(vector_input, self.running_mean)
new_mean = self.running_mean + tf.reduce_sum(
input_to_old_mean = tf.subtract(vector_input, running_mean)
new_mean = running_mean + tf.reduce_sum(
new_variance = self.running_variance + tf.reduce_sum(
new_variance = running_variance + tf.reduce_sum(
update_mean = tf.assign(self.running_mean, new_mean)
update_variance = tf.assign(self.running_variance, new_variance)
update_norm_step = tf.assign(self.normalization_steps, total_new_steps)
update_mean = tf.assign(running_mean, new_mean)
update_variance = tf.assign(running_variance, new_variance)
update_norm_step = tf.assign(steps, total_new_steps)
return tf.group([update_mean, update_variance, update_norm_step])
@staticmethod

@staticmethod
def _check_resolution_for_encoder(
camera_res: CameraResolution, vis_encoder_type: EncoderType
vis_in: tf.Tensor, vis_encoder_type: EncoderType
if camera_res.height < min_res or camera_res.width < min_res:
height = vis_in.shape[1]
width = vis_in.shape[2]
if height < min_res or width < min_res:
f"Visual observation resolution ({camera_res.width}x{camera_res.height}) is too small for"
f"Visual observation resolution ({width}x{height}) is too small for"
@staticmethod
self,
visual_in: List[tf.Tensor],
vector_in: tf.Tensor,
num_streams: int,
h_size: int,
num_layers: int,

the scopes for each of the streams. None if all under the same TF scope.
:return: List of encoded streams.
"""
brain = self.brain
activation_fn = self.swish
self.visual_in = []
for i in range(brain.number_visual_observations):
LearningModel._check_resolution_for_encoder(
brain.camera_resolutions[i], vis_encode_type
)
visual_input = self.create_visual_input(
brain.camera_resolutions[i], name="visual_observation_" + str(i)
)
self.visual_in.append(visual_input)
vector_observation_input = self.create_vector_input()
activation_fn = LearningModel.swish
vector_observation_input = vector_in
final_hiddens = []
for i in range(num_streams):

visual_encoders = []
hidden_state, hidden_visual = None, None
_scope_add = stream_scopes[i] if stream_scopes else ""
if self.vis_obs_size > 0:
for j in range(brain.number_visual_observations):
if len(visual_in) > 0:
for j, vis_in in enumerate(visual_in):
LearningModel._check_resolution_for_encoder(vis_in, vis_encode_type)
self.visual_in[j],
vis_in,
h_size,
activation_fn,
num_layers,

visual_encoders.append(encoded_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
if brain.vector_observation_space_size > 0:
hidden_state = self.create_vector_observation_encoder(
if vector_in is not None:
hidden_state = LearningModel.create_vector_observation_encoder(
vector_observation_input,
h_size,
activation_fn,

37
ml-agents/mlagents/trainers/ppo/models.py


self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy
else:
self.create_dc_actor_critic(h_size, num_layers, vis_encode_type)
self.create_losses(
self.log_probs,
self.old_log_probs,

:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
2, h_size, num_layers, vis_encode_type
hidden_streams = LearningModel.create_observation_streams(
self.visual_in, self.processed_vector_in, 2, h_size, num_layers, vis_encode_type
)
if self.use_recurrent:

:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
1, h_size, num_layers, vis_encode_type
2, h_size, num_layers, vis_encode_type
hidden = hidden_streams[0]
if self.use_recurrent:
self.prev_action = tf.placeholder(

],
axis=1,
)
hidden = tf.concat([hidden, prev_action_oh], axis=1)
hidden_policy = tf.concat([hidden_streams[0], prev_action_oh], axis=1)
hidden, memory_out = self.create_recurrent_encoder(
hidden, self.memory_in, self.sequence_length
_half_point = int(self.m_size / 2)
hidden_policy, memory_policy_out = self.create_recurrent_encoder(
hidden_policy,
self.memory_in[:, :_half_point],
self.sequence_length,
name="lstm_policy",
)
hidden_value, memory_value_out = self.create_recurrent_encoder(
hidden_streams[1],
self.memory_in[:, _half_point:],
self.sequence_length,
name="lstm_value",
self.memory_out = tf.identity(memory_out, name="recurrent_out")
self.memory_out = tf.concat(
[memory_policy_out, memory_value_out], axis=1, name="recurrent_out"
)
else:
hidden_policy = hidden_streams[0]
hidden_value = hidden_streams[1]
hidden,
hidden_policy,
size,
activation=None,
use_bias=False,

self.output = tf.identity(output)
self.normalized_logits = tf.identity(normalized_logits, name="action")
self.create_value_heads(self.stream_names, hidden)
self.create_value_heads(self.stream_names, hidden_value)
self.action_holder = tf.placeholder(
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder"

403
ml-agents/mlagents/trainers/ppo/optimizer.py


import logging
from typing import Optional
import numpy as np
from mlagents.tf_utils import tf
from mlagents.trainers.models import LearningModel, EncoderType, LearningRateSchedule
logger = logging.getLogger("mlagents.trainers")
class PPOOptimizer(LearningModel):
def __init__(
self,
brain,
policy,
lr=1e-4,
lr_schedule=LearningRateSchedule.LINEAR,
h_size=128,
epsilon=0.2,
beta=1e-3,
max_step=5e6,
normalize=False,
use_recurrent=False,
num_layers=2,
m_size=None,
seed=0,
stream_names=None,
vis_encode_type=EncoderType.SIMPLE
):
"""
Takes a Unity environment and model-specific hyper-parameters and returns the
appropriate PPO agent model for the environment.
:param brain: brain parameters used to generate specific network graph.
:param lr: Learning rate.
:param lr_schedule: Learning rate decay schedule.
:param h_size: Size of hidden layers
:param epsilon: Value for policy-divergence threshold.
:param beta: Strength of entropy regularization.
:param max_step: Total number of training steps.
:param normalize: Whether to normalize vector observation input.
:param use_recurrent: Whether to use an LSTM layer in the network.
:param num_layers Number of hidden layers between encoded input and policy & value layers
:param m_size: Size of brain memory.
:param seed: Seed to use for initialization of model.
:param stream_names: List of names of value streams. Usually, a list of the Reward Signals being used.
:return: a sub-class of PPOAgent tailored to the environment.
"""
LearningModel.__init__(
self, m_size, normalize, use_recurrent, brain, seed, stream_names
)
self.optimizer: Optional[tf.train.AdamOptimizer] = None
self.grads = None
self.update_batch: Optional[tf.Operation] = None
self.policy = policy
if num_layers < 1:
num_layers = 1
if brain.vector_action_space_type == "continuous":
self.create_cc_critic(h_size, num_layers, vis_encode_type)
self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy
else:
self.create_dc_actor_critic(h_size, num_layers, vis_encode_type)
self.learning_rate = self.create_learning_rate(
lr_schedule, lr, self.global_step, max_step
)
self.vector_in = self.policy.vector_in
self.visual_in = self.policy.visual_in
self.create_losses(
self.log_probs,
self.old_log_probs,
self.value_heads,
self.entropy,
beta,
epsilon,
lr,
max_step,
)
def create_cc_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""
Creates Continuous control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = LearningModel.create_observation_streams(
self.policy.visual_in, self.policy.processed_vector_in, 1, h_size, num_layers, vis_encode_type
)
if self.use_recurrent:
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
_half_point = int(self.m_size / 2)
hidden_policy, memory_policy_out = self.create_recurrent_encoder(
hidden_streams[0],
self.memory_in[:, :_half_point],
self.sequence_length,
name="lstm_policy",
)
hidden_value, memory_value_out = self.create_recurrent_encoder(
hidden_streams[1],
self.memory_in[:, _half_point:],
self.sequence_length,
name="lstm_value",
)
self.memory_out = tf.concat(
[memory_policy_out, memory_value_out], axis=1, name="recurrent_out"
)
else:
hidden_policy = hidden_streams[0]
hidden_value = hidden_streams[1]
mu = tf.layers.dense(
hidden_policy,
self.act_size[0],
activation=None,
kernel_initializer=LearningModel.scaled_init(0.01),
reuse=tf.AUTO_REUSE,
)
self.log_sigma_sq = tf.get_variable(
"log_sigma_squared",
[self.act_size[0]],
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
sigma_sq = tf.exp(self.log_sigma_sq)
self.epsilon = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon"
)
# Clip and scale output to ensure actions are always within [-1, 1] range.
self.output_pre = mu + tf.sqrt(sigma_sq) * self.epsilon
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
self.output = tf.identity(output_post, name="action")
self.selected_actions = tf.stop_gradient(output_post)
# Compute probability of model output.
all_probs = (
-0.5 * tf.square(tf.stop_gradient(self.output_pre) - mu) / sigma_sq
- 0.5 * tf.log(2.0 * np.pi)
- 0.5 * self.log_sigma_sq
)
self.all_log_probs = tf.identity(all_probs, name="action_probs")
self.entropy = 0.5 * tf.reduce_mean(
tf.log(2 * np.pi * np.e) + self.log_sigma_sq
)
self.create_value_heads(self.stream_names, hidden_value)
self.all_old_log_probs = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="old_probabilities"
)
# We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
self.log_probs = tf.reduce_sum(
(tf.identity(self.all_log_probs)), axis=1, keepdims=True
)
self.old_log_probs = tf.reduce_sum(
(tf.identity(self.all_old_log_probs)), axis=1, keepdims=True
)
def create_dc_actor_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""
Creates Discrete control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
2, h_size, num_layers, vis_encode_type
)
if self.use_recurrent:
self.prev_action = tf.placeholder(
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
)
prev_action_oh = tf.concat(
[
tf.one_hot(self.prev_action[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
hidden_policy = tf.concat([hidden_streams[0], prev_action_oh], axis=1)
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
_half_point = int(self.m_size / 2)
hidden_policy, memory_policy_out = self.create_recurrent_encoder(
hidden_policy,
self.memory_in[:, :_half_point],
self.sequence_length,
name="lstm_policy",
)
hidden_value, memory_value_out = self.create_recurrent_encoder(
hidden_streams[1],
self.memory_in[:, _half_point:],
self.sequence_length,
name="lstm_value",
)
self.memory_out = tf.concat(
[memory_policy_out, memory_value_out], axis=1, name="recurrent_out"
)
else:
hidden_policy = hidden_streams[0]
hidden_value = hidden_streams[1]
policy_branches = []
for size in self.act_size:
policy_branches.append(
tf.layers.dense(
hidden_policy,
size,
activation=None,
use_bias=False,
kernel_initializer=LearningModel.scaled_init(0.01),
)
)
self.all_log_probs = tf.concat(policy_branches, axis=1, name="action_probs")
self.action_masks = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
)
output, _, normalized_logits = self.create_discrete_action_masking_layer(
self.all_log_probs, self.action_masks, self.act_size
)
self.output = tf.identity(output)
self.normalized_logits = tf.identity(normalized_logits, name="action")
self.create_value_heads(self.stream_names, hidden_value)
self.action_holder = tf.placeholder(
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder"
)
self.action_oh = tf.concat(
[
tf.one_hot(self.action_holder[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
self.selected_actions = tf.stop_gradient(self.action_oh)
self.all_old_log_probs = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="old_probabilities"
)
_, _, old_normalized_logits = self.create_discrete_action_masking_layer(
self.all_old_log_probs, self.action_masks, self.act_size
)
action_idx = [0] + list(np.cumsum(self.act_size))
self.entropy = tf.reduce_sum(
(
tf.stack(
[
tf.nn.softmax_cross_entropy_with_logits_v2(
labels=tf.nn.softmax(
self.all_log_probs[:, action_idx[i] : action_idx[i + 1]]
),
logits=self.all_log_probs[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
)
self.log_probs = tf.reduce_sum(
(
tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]],
logits=normalized_logits[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
keepdims=True,
)
self.old_log_probs = tf.reduce_sum(
(
tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]],
logits=old_normalized_logits[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
keepdims=True,
)
def create_losses(
self, probs, old_probs, value_heads, entropy, beta, epsilon, lr, max_step
):
"""
Creates training-specific Tensorflow ops for PPO models.
:param probs: Current policy probabilities
:param old_probs: Past policy probabilities
:param value_heads: Value estimate tensors from each value stream
:param beta: Entropy regularization strength
:param entropy: Current policy entropy
:param epsilon: Value for policy-divergence threshold
:param lr: Learning rate
:param max_step: Total number of training steps.
"""
self.returns_holders = {}
self.old_values = {}
for name in value_heads.keys():
returns_holder = tf.placeholder(
shape=[None], dtype=tf.float32, name="{}_returns".format(name)
)
old_value = tf.placeholder(
shape=[None], dtype=tf.float32, name="{}_value_estimate".format(name)
)
self.returns_holders[name] = returns_holder
self.old_values[name] = old_value
self.advantage = tf.placeholder(
shape=[None], dtype=tf.float32, name="advantages"
)
advantage = tf.expand_dims(self.advantage, -1)
decay_epsilon = tf.train.polynomial_decay(
epsilon, self.global_step, max_step, 0.1, power=1.0
)
decay_beta = tf.train.polynomial_decay(
beta, self.global_step, max_step, 1e-5, power=1.0
)
value_losses = []
for name, head in value_heads.items():
clipped_value_estimate = self.old_values[name] + tf.clip_by_value(
tf.reduce_sum(head, axis=1) - self.old_values[name],
-decay_epsilon,
decay_epsilon,
)
v_opt_a = tf.squared_difference(
self.returns_holders[name], tf.reduce_sum(head, axis=1)
)
v_opt_b = tf.squared_difference(
self.returns_holders[name], clipped_value_estimate
)
value_loss = tf.reduce_mean(
tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.mask, 2)[1]
)
value_losses.append(value_loss)
self.value_loss = tf.reduce_mean(value_losses)
r_theta = tf.exp(probs - old_probs)
p_opt_a = r_theta * advantage
p_opt_b = (
tf.clip_by_value(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon)
* advantage
)
self.policy_loss = -tf.reduce_mean(
tf.dynamic_partition(tf.minimum(p_opt_a, p_opt_b), self.mask, 2)[1]
)
# For cleaner stats reporting
self.abs_policy_loss = tf.abs(self.policy_loss)
self.loss = (
self.policy_loss
+ 0.5 * self.value_loss
- decay_beta
* tf.reduce_mean(tf.dynamic_partition(entropy, self.mask, 2)[1])
)
def create_ppo_optimizer(self):
self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
self.grads = self.optimizer.compute_gradients(self.loss)
self.update_batch = self.optimizer.minimize(self.loss)
正在加载...
取消
保存