浏览代码

Remove PPO model

/develop/nopreviousactions
Ervin Teng 5 年前
当前提交
6baaf980
共有 1 个文件被更改,包括 0 次插入263 次删除
  1. 263
      ml-agents/mlagents/trainers/ppo/models.py

263
ml-agents/mlagents/trainers/ppo/models.py


import logging
from typing import Optional
import numpy as np
from mlagents.tf_utils import tf
from mlagents.trainers.models import LearningModel, EncoderType, LearningRateSchedule
logger = logging.getLogger("mlagents.trainers")
class PPOModel(LearningModel):
def __init__(
self,
brain,
lr=1e-4,
lr_schedule=LearningRateSchedule.LINEAR,
h_size=128,
epsilon=0.2,
beta=1e-3,
max_step=5e6,
normalize=False,
use_recurrent=False,
num_layers=2,
m_size=None,
seed=0,
stream_names=None,
vis_encode_type=EncoderType.SIMPLE,
):
"""
Takes a Unity environment and model-specific hyper-parameters and returns the
appropriate PPO agent model for the environment.
:param brain: brain parameters used to generate specific network graph.
:param lr: Learning rate.
:param lr_schedule: Learning rate decay schedule.
:param h_size: Size of hidden layers
:param epsilon: Value for policy-divergence threshold.
:param beta: Strength of entropy regularization.
:param max_step: Total number of training steps.
:param normalize: Whether to normalize vector observation input.
:param use_recurrent: Whether to use an LSTM layer in the network.
:param num_layers Number of hidden layers between encoded input and policy & value layers
:param m_size: Size of brain memory.
:param seed: Seed to use for initialization of model.
:param stream_names: List of names of value streams. Usually, a list of the Reward Signals being used.
:return: a sub-class of PPOAgent tailored to the environment.
"""
LearningModel.__init__(
self, m_size, normalize, use_recurrent, brain, seed, stream_names
)
self.optimizer: Optional[tf.train.AdamOptimizer] = None
self.grads = None
self.update_batch: Optional[tf.Operation] = None
if num_layers < 1:
num_layers = 1
if brain.vector_action_space_type == "continuous":
self.create_cc_actor(h_size, num_layers, vis_encode_type)
else:
self.create_dc_actor(h_size, num_layers, vis_encode_type)
def create_cc_actor(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""
Creates Continuous control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_stream = LearningModel.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
h_size,
num_layers,
vis_encode_type,
stream_scopes=["policy"],
)[0]
if self.use_recurrent:
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
_half_point = int(self.m_size / 2)
hidden_policy, memory_policy_out = self.create_recurrent_encoder(
hidden_stream,
self.memory_in[:, :_half_point],
self.sequence_length,
name="lstm_policy",
)
self.memory_out = memory_policy_out
else:
hidden_policy = hidden_stream
mu = tf.layers.dense(
hidden_policy,
self.act_size[0],
activation=None,
kernel_initializer=LearningModel.scaled_init(0.01),
reuse=tf.AUTO_REUSE,
)
self.log_sigma_sq = tf.get_variable(
"log_sigma_squared",
[self.act_size[0]],
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
sigma_sq = tf.exp(self.log_sigma_sq)
self.epsilon = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon"
)
# Clip and scale output to ensure actions are always within [-1, 1] range.
self.output_pre = mu + tf.sqrt(sigma_sq) * self.epsilon
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
self.output = tf.identity(output_post, name="action")
self.selected_actions = tf.stop_gradient(output_post)
# Compute probability of model output.
all_probs = (
-0.5 * tf.square(tf.stop_gradient(self.output_pre) - mu) / sigma_sq
- 0.5 * tf.log(2.0 * np.pi)
- 0.5 * self.log_sigma_sq
)
self.all_log_probs = tf.identity(all_probs, name="action_probs")
single_dim_entropy = 0.5 * tf.reduce_mean(
tf.log(2 * np.pi * np.e) + self.log_sigma_sq
)
# Make entropy the right shape
self.entropy = tf.ones_like(tf.reshape(mu[:, 0], [-1])) * single_dim_entropy
# We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
self.log_probs = tf.reduce_sum(
(tf.identity(self.all_log_probs)), axis=1, keepdims=True
)
def create_dc_actor(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""
Creates Discrete control actor-critic model.
:param h_size: Size of hidden linear layers.
:param num_layers: Number of hidden linear layers.
"""
hidden_stream = self.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
h_size,
num_layers,
vis_encode_type,
stream_scopes=["policy"],
)[0]
if self.use_recurrent:
self.prev_action = tf.placeholder(
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
)
prev_action_oh = tf.concat(
[
tf.one_hot(self.prev_action[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
hidden_policy = tf.concat([hidden_stream, prev_action_oh], axis=1)
self.memory_in = tf.placeholder(
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
)
_half_point = int(self.m_size / 2)
hidden_policy, memory_policy_out = self.create_recurrent_encoder(
hidden_policy,
self.memory_in[:, :_half_point],
self.sequence_length,
name="lstm_policy",
)
self.memory_out = memory_policy_out
else:
hidden_policy = hidden_stream
policy_branches = []
for size in self.act_size:
policy_branches.append(
tf.layers.dense(
hidden_policy,
size,
activation=None,
use_bias=False,
kernel_initializer=LearningModel.scaled_init(0.01),
)
)
self.all_log_probs = tf.concat(policy_branches, axis=1, name="action_probs")
self.action_masks = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
)
output, _, normalized_logits = self.create_discrete_action_masking_layer(
self.all_log_probs, self.action_masks, self.act_size
)
self.output = tf.identity(output)
self.normalized_logits = tf.identity(normalized_logits, name="action")
self.action_holder = tf.placeholder(
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder"
)
self.action_oh = tf.concat(
[
tf.one_hot(self.action_holder[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
self.selected_actions = tf.stop_gradient(self.action_oh)
action_idx = [0] + list(np.cumsum(self.act_size))
self.entropy = tf.reduce_sum(
(
tf.stack(
[
tf.nn.softmax_cross_entropy_with_logits_v2(
labels=tf.nn.softmax(
self.all_log_probs[:, action_idx[i] : action_idx[i + 1]]
),
logits=self.all_log_probs[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
)
self.log_probs = tf.reduce_sum(
(
tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]],
logits=normalized_logits[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
keepdims=True,
)
正在加载...
取消
保存