
Add flags for normalization and variable layers

/tag-0.2.0
Arthur Juliani, 7 years ago
Commit 1bf46a85
4 changed files with 69 additions and 42 deletions
  1. python/PPO.ipynb (22 changes)
  2. python/ppo.py (9 changes)
  3. python/ppo/models.py (76 changes)
  4. python/ppo/trainer.py (4 changes)

python/PPO.ipynb (22 changes)


"time_horizon = 2048 # How many steps to collect per agent before adding to buffer.\n",
"beta = 1e-3 # Strength of entropy regularization\n",
"num_epoch = 5 # Number of gradient descent steps per batch of experiences.\n",
"num_layers = 2 # Number of hidden layers between state/observation encoding and value/policy layers.\n",
"normalize = False\n",
"\n",
"### Logging dictionary for hyperparameters\n",
"hyperparameter_dict = {'max_steps':max_steps, 'run_path':run_path, 'env_name':env_name,\n",

{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"env = UnityEnvironment(file_name=env_name, curriculum=curriculum_file)\n",

"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],

"# Create the Tensorflow model graph\n",
"ppo_model = create_agent_model(env, lr=learning_rate,\n",
" h_size=hidden_units, epsilon=epsilon,\n",
" beta=beta, max_step=max_steps)\n",
" beta=beta, max_step=max_steps, \n",
" normalize=normalize, num_layers=num_layers)\n",
"\n",
"is_continuous = (env.brains[brain_name].action_space_type == \"continuous\")\n",
"use_observations = (env.brains[brain_name].number_observations > 0)\n",

" if env.global_done:\n",
" info = env.reset(train_mode=train_model, progress=get_progress())[brain_name]\n",
" # Decide and take an action\n",
" new_info = trainer.take_action(info, env, brain_name, steps)\n",
" new_info = trainer.take_action(info, env, brain_name, steps, normalize)\n",
" info = new_info\n",
" trainer.process_experiences(info, time_horizon, gamma, lambd)\n",
" if len(trainer.training_buffer['actions']) > buffer_size and train_model:\n",

"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 2",
"display_name": "Python 3",
"name": "python2"
"name": "python3"
"version": 2
"version": 3
"pygments_lexer": "ipython2",
"version": "2.7.10"
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,

python/ppo.py (9 changes)


--learning-rate=<rate> Model learning rate [default: 3e-4].
--load Whether to load the model or randomly initialize [default: False].
--max-steps=<n> Maximum number of steps to run environment [default: 1e6].
+ --normalize Whether to normalize the state input using running statistics [default: False].
+ --num-layers=<n> Number of hidden layers between state/observation and outputs [default: 2].
--run-path=<path> The sub-directory name for model and summary statistics [default: ppo].
--save-freq=<n> Frequency at which to save model [default: 50000].
--summary-freq=<n> Frequency at which to save training statistics [default: 10000].
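As a usage illustration only (the script's full docopt usage string is not part of this diff, so the positional environment argument below is an assumption), the two new flags would be passed roughly like:

python ppo.py <env_name> --normalize --num-layers=3 --run-path=ppo_normalized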

time_horizon = int(options['--time-horizon'])
beta = float(options['--beta'])
num_epoch = int(options['--num-epoch'])
+ num_layers = int(options['--num-layers'])
+ normalize = options['--normalize']
env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_file)
print(str(env))

# Create the Tensorflow model graph
ppo_model = create_agent_model(env, lr=learning_rate,
h_size=hidden_units, epsilon=epsilon,
- beta=beta, max_step=max_steps)
+ beta=beta, max_step=max_steps,
+ normalize=normalize, num_layers=num_layers)
is_continuous = (env.brains[brain_name].action_space_type == "continuous")
use_observations = (env.brains[brain_name].number_observations > 0)

info = env.reset(train_mode=train_model, progress=get_progress())[brain_name]
trainer.reset_buffers(info, total=True)
# Decide and take an action
- new_info = trainer.take_action(info, env, brain_name, steps)
+ new_info = trainer.take_action(info, env, brain_name, steps, normalize)
info = new_info
trainer.process_experiences(info, time_horizon, gamma, lambd)
if len(trainer.training_buffer['actions']) > buffer_size and train_model:

python/ppo/models.py (76 changes)


from unityagents import UnityEnvironmentException
- def create_agent_model(env, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step=5e6):
+ def create_agent_model(env, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step=5e6, normalize=False, num_layers=2):
"""
Takes a Unity environment and model-specific hyper-parameters and returns the
appropriate PPO agent model for the environment.

:return: a sub-class of PPOAgent tailored to the environment.
:param max_step: Total number of training steps.
"""
+ if num_layers < 1: num_layers = 1
- return ContinuousControlModel(lr, brain, h_size, epsilon, max_step)
+ return ContinuousControlModel(lr, brain, h_size, epsilon, max_step, normalize, num_layers)
- return DiscreteControlModel(lr, brain, h_size, epsilon, beta, max_step)
+ return DiscreteControlModel(lr, brain, h_size, epsilon, beta, max_step, normalize, num_layers)
def save_model(sess, saver, model_path="./", steps=0):
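For orientation, a minimal sketch of how the updated factory plausibly threads the new arguments into the two model classes. The brain lookup and the action-space check below are assumptions based on checks visible elsewhere in this commit, not lines shown in this hunk:

def create_agent_model_sketch(env, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3,
                              max_step=5e6, normalize=False, num_layers=2):
    # Assumed lookup: a single external brain, keyed by name as in the training scripts above.
    brain_name = list(env.brains.keys())[0]
    brain = env.brains[brain_name]
    # Guard added in this commit: never build fewer than one hidden layer.
    if num_layers < 1:
        num_layers = 1
    if brain.action_space_type == "continuous":
        return ContinuousControlModel(lr, brain, h_size, epsilon, max_step, normalize, num_layers)
    return DiscreteControlModel(lr, brain, h_size, epsilon, beta, max_step, normalize, num_layers)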

class PPOModel(object):
def __init__(self):
+ self.normalize = False
def create_global_steps(self):
"""Creates TF ops to track and increment global training step."""
self.global_step = tf.Variable(0, name="global_step", trainable=False, dtype=tf.int32)

self.new_reward = tf.placeholder(shape=[], dtype=tf.float32, name='new_reward')
self.update_reward = tf.assign(self.last_reward, self.new_reward)
- def create_visual_encoder(self, o_size_h, o_size_w, bw, h_size, num_streams, activation):
+ def create_visual_encoder(self, o_size_h, o_size_w, bw, h_size, num_streams, activation, num_layers):
"""
Builds a set of visual (CNN) encoders.
:param o_size_h: Height observation size.

use_bias=False, activation=activation)
self.conv2 = tf.layers.conv2d(self.conv1, 32, kernel_size=[4, 4], strides=[2, 2],
use_bias=False, activation=activation)
- hidden = tf.layers.dense(c_layers.flatten(self.conv2), h_size, use_bias=False, activation=activation)
+ hidden = c_layers.flatten(self.conv2)
+ for j in range(num_layers):
+     hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation)
- def create_continuous_state_encoder(self, s_size, h_size, num_streams, activation):
+ def create_continuous_state_encoder(self, s_size, h_size, num_streams, activation, num_layers):
"""
Builds a set of hidden state encoders.
:param s_size: state input size.

"""
self.state_in = tf.placeholder(shape=[None, s_size], dtype=tf.float32, name='state')
- self.running_mean = tf.get_variable("running_mean", [s_size], trainable=False, dtype=tf.float32,
-     initializer=tf.zeros_initializer())
- self.running_variance = tf.get_variable("running_variance", [s_size], trainable=False, dtype=tf.float32,
-     initializer=tf.ones_initializer())
+ if self.normalize:
+     self.running_mean = tf.get_variable("running_mean", [s_size], trainable=False, dtype=tf.float32,
+         initializer=tf.zeros_initializer())
+     self.running_variance = tf.get_variable("running_variance", [s_size], trainable=False, dtype=tf.float32,
+         initializer=tf.ones_initializer())
- self.normalized_state = tf.clip_by_value((self.state_in - self.running_mean) / tf.sqrt(
-     self.running_variance / (tf.cast(self.global_step, tf.float32) + 1)), -5, 5, name="normalized_state")
+     self.normalized_state = tf.clip_by_value((self.state_in - self.running_mean) / tf.sqrt(
+         self.running_variance / (tf.cast(self.global_step, tf.float32) + 1)), -5, 5, name="normalized_state")
- self.new_mean = tf.placeholder(shape=[s_size], dtype=tf.float32, name='new_mean')
- self.new_variance = tf.placeholder(shape=[s_size], dtype=tf.float32, name='new_variance')
- self.update_mean = tf.assign(self.running_mean, self.new_mean)
- self.update_variance = tf.assign(self.running_variance, self.new_variance)
+     self.new_mean = tf.placeholder(shape=[s_size], dtype=tf.float32, name='new_mean')
+     self.new_variance = tf.placeholder(shape=[s_size], dtype=tf.float32, name='new_variance')
+     self.update_mean = tf.assign(self.running_mean, self.new_mean)
+     self.update_variance = tf.assign(self.running_variance, self.new_variance)
+ else:
+     self.normalized_state = self.state_in
- hidden_1 = tf.layers.dense(self.normalized_state, h_size, use_bias=False, activation=activation)
- hidden_2 = tf.layers.dense(hidden_1, h_size, use_bias=False, activation=activation)
- streams.append(hidden_2)
+ hidden = self.normalized_state
+ for j in range(num_layers):
+     hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation)
+ streams.append(hidden)
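The clipped running-statistics normalization above can be read as a plain-NumPy sketch (illustrative only; the real graph uses the TF variables and global_step shown in this hunk):

import numpy as np

def normalize_state_sketch(state, running_mean, running_variance, step):
    # Mirrors the TF expression: clip((x - mean) / sqrt(variance_sum / (step + 1)), -5, 5).
    # running_variance accumulates a sum of squared deviations (see trainer.running_average
    # in the trainer.py hunk below), so dividing by (step + 1) gives an approximate variance.
    std = np.sqrt(running_variance / (step + 1.0))
    return np.clip((state - running_mean) / std, -5, 5)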
- def create_discrete_state_encoder(self, s_size, h_size, num_streams, activation):
+ def create_discrete_state_encoder(self, s_size, h_size, num_streams, activation, num_layers):
"""
Builds a set of hidden state encoders from discrete state input.
:param s_size: state input size (discrete).

state_in = tf.reshape(self.state_in, [-1])
state_onehot = c_layers.one_hot_encoding(state_in, s_size)
streams = []
+ hidden = state_onehot
- hidden = tf.layers.dense(state_onehot, h_size, use_bias=False, activation=activation)
+ for j in range(num_layers):
+     hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation)
streams.append(hidden)
return streams

class ContinuousControlModel(PPOModel):
- def __init__(self, lr, brain, h_size, epsilon, max_step):
+ def __init__(self, lr, brain, h_size, epsilon, max_step, normalize, num_layers):
+ super().__init__()
+ self.normalize = normalize
self.create_global_steps()
self.create_reward_encoder()

bw = brain.camera_resolutions[0]['blackAndWhite']
- hidden_visual = self.create_visual_encoder(height_size, width_size, bw, h_size, 2, tf.nn.tanh)
+ hidden_visual = self.create_visual_encoder(height_size, width_size, bw, h_size, 2, tf.nn.tanh, num_layers)
- hidden_state = self.create_continuous_state_encoder(s_size, h_size, 2, tf.nn.tanh)
+ hidden_state = self.create_continuous_state_encoder(s_size, h_size, 2, tf.nn.tanh, num_layers)
- hidden_state = self.create_discrete_state_encoder(s_size, h_size, 2, tf.nn.tanh)
+ hidden_state = self.create_discrete_state_encoder(s_size, h_size, 2, tf.nn.tanh, num_layers)
if hidden_visual is None and hidden_state is None:
raise Exception("No valid network configuration possible. "

class DiscreteControlModel(PPOModel):
- def __init__(self, lr, brain, h_size, epsilon, beta, max_step):
+ def __init__(self, lr, brain, h_size, epsilon, beta, max_step, normalize, num_layers):
+ super().__init__()
+ self.normalize = normalize
- hidden_visual = self.create_visual_encoder(height_size, width_size, bw, h_size, 1, tf.nn.elu)[0]
+ hidden_visual = self.create_visual_encoder(height_size, width_size, bw, h_size, 1, tf.nn.elu, num_layers)[0]
- hidden_state = self.create_continuous_state_encoder(s_size, h_size, 1, tf.nn.elu)[0]
+ hidden_state = self.create_continuous_state_encoder(s_size, h_size, 1, tf.nn.elu, num_layers)[0]
- hidden_state = self.create_discrete_state_encoder(s_size, h_size, 1, tf.nn.elu)[0]
+ hidden_state = self.create_discrete_state_encoder(s_size, h_size, 1, tf.nn.elu, num_layers)[0]
if hidden_visual is None and hidden_state is None:
raise Exception("No valid network configuration possible. "

python/ppo/trainer.py (4 changes)


new_variance = var + (current_x - new_mean) * (current_x - mean)
return new_mean, new_variance
- def take_action(self, info, env, brain_name, steps):
+ def take_action(self, info, env, brain_name, steps, normalize):
"""
Decides actions given state/observation information, and takes them in environment.
:param info: Current BrainInfo from environment.

feed_dict[self.model.observation_in] = np.vstack(info.observations)
if self.use_states:
feed_dict[self.model.state_in] = info.states
- if self.is_training and env.brains[brain_name].state_space_type == "continuous" and self.use_states:
+ if self.is_training and env.brains[brain_name].state_space_type == "continuous" and self.use_states and normalize:
new_mean, new_variance = self.running_average(info.states, steps, self.model.running_mean,
self.model.running_variance)
feed_dict[self.model.new_mean] = new_mean
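The running_average call above follows a Welford-style accumulation of the mean and of squared deviations. A hedged NumPy sketch of that update (only the variance line appears in this diff; the mean-update line and the batch averaging are assumptions):

import numpy as np

def running_average_sketch(states, steps, running_mean, running_variance):
    # Fold the mean of the current batch of states into the running statistics.
    current_x = np.mean(states, axis=0)
    new_mean = running_mean + (current_x - running_mean) / (steps + 1)  # assumed update
    new_variance = running_variance + (current_x - new_mean) * (current_x - running_mean)  # as shown above
    return new_mean, new_variance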
