self.new_reward = tf.placeholder(shape=[], dtype=tf.float32, name='new_reward')
self.update_reward = tf.assign(self.last_reward, self.new_reward)
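# --- Hedged usage sketch (illustrative only, not part of this file) ----------
# Shows how an assign op like `update_reward` above is typically driven from
# Python: the new value is fed through the placeholder and written into the
# variable. The `demo_` names and the value 0.75 are assumptions.
import tensorflow as tf

demo_last_reward = tf.Variable(0.0, trainable=False, name='demo_last_reward')
demo_new_reward = tf.placeholder(shape=[], dtype=tf.float32, name='demo_new_reward')
demo_update_reward = tf.assign(demo_last_reward, demo_new_reward)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(demo_update_reward, feed_dict={demo_new_reward: 0.75})
    print(sess.run(demo_last_reward))  # -> 0.75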
# Earlier version: the memory placeholder is created inside the encoder and the
# recurrent tensors live on the model object.
def create_recurrent_encoder(self, s_size, input_state):
    self.lstm_input_state = tf.reshape(input_state, shape=[-1, self.sequence_length, s_size])
    self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in')
    _half_point = int(self.m_size / 2)
    rnn_cell = tf.contrib.rnn.BasicLSTMCell(_half_point)
    lstm_state_in = tf.contrib.rnn.LSTMStateTuple(self.memory_in[:, :_half_point],
                                                  self.memory_in[:, _half_point:])
    self.recurrent_state, self.lstm_state_out = tf.nn.dynamic_rnn(rnn_cell, self.lstm_input_state,
                                                                  initial_state=lstm_state_in,
                                                                  time_major=False,
                                                                  dtype=tf.float32)
    self.memory_out = tf.concat([self.lstm_state_out.c, self.lstm_state_out.h], axis=1)
    self.memory_out = tf.identity(self.memory_out, name='recurrent_out')
    recurrent_state = tf.reshape(self.recurrent_state, shape=[-1, _half_point])
    return recurrent_state

# Updated version: the memory tensor is passed in and returned, so the encoder can
# be instantiated several times (e.g. separate LSTMs for the policy and value
# streams) under its own variable scope. The flat memory vector carries the LSTM
# cell state in its first half and the hidden state in its second half.
def create_recurrent_encoder(self, input_state, memory_in, name='lstm'):
    s_size = input_state.get_shape().as_list()[1]
    m_size = memory_in.get_shape().as_list()[1]
    lstm_input_state = tf.reshape(input_state, shape=[-1, self.sequence_length, s_size])
    _half_point = int(m_size / 2)
    with tf.variable_scope(name):
        rnn_cell = tf.contrib.rnn.BasicLSTMCell(_half_point)
        lstm_state_in = tf.contrib.rnn.LSTMStateTuple(memory_in[:, :_half_point],
                                                      memory_in[:, _half_point:])
        recurrent_state, lstm_state_out = tf.nn.dynamic_rnn(rnn_cell, lstm_input_state,
                                                            initial_state=lstm_state_in,
                                                            time_major=False,
                                                            dtype=tf.float32)
    recurrent_state = tf.reshape(recurrent_state, shape=[-1, _half_point])
    return recurrent_state, tf.concat([lstm_state_out.c, lstm_state_out.h], axis=1)
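# --- Hedged shape sketch (illustrative only, not part of this file) ----------
# Standalone demonstration of the memory layout used by create_recurrent_encoder:
# the flat memory vector of width m_size is split into the LSTM cell state (first
# half) and hidden state (second half), and re-concatenated on the way out. All
# sizes below are arbitrary assumptions.
import numpy as np
import tensorflow as tf

seq_len, s_size, m_size = 8, 32, 256
half = m_size // 2

flat_input = tf.placeholder(tf.float32, [None, s_size])     # [batch * seq_len, s_size]
flat_memory = tf.placeholder(tf.float32, [None, m_size])    # [batch, m_size]

seq_input = tf.reshape(flat_input, [-1, seq_len, s_size])   # [batch, seq_len, s_size]
cell = tf.contrib.rnn.BasicLSTMCell(half)
state_in = tf.contrib.rnn.LSTMStateTuple(flat_memory[:, :half], flat_memory[:, half:])
outputs, state_out = tf.nn.dynamic_rnn(cell, seq_input, initial_state=state_in, dtype=tf.float32)

hidden = tf.reshape(outputs, [-1, half])                     # [batch * seq_len, half]
memory_out = tf.concat([state_out.c, state_out.h], axis=1)   # [batch, m_size]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    h, m = sess.run([hidden, memory_out],
                    feed_dict={flat_input: np.zeros((4 * seq_len, s_size), np.float32),
                               flat_memory: np.zeros((4, m_size), np.float32)})
    print(h.shape, m.shape)  # (32, 128) (4, 256)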
def create_visual_encoder(self, o_size_h, o_size_w, bw, h_size, num_streams, activation, num_layers):
    """

                              use_bias=False, activation=activation)
    self.conv2 = tf.layers.conv2d(self.conv1, 32, kernel_size=[4, 4], strides=[2, 2],
                              use_bias=False, activation=activation)
    if self.use_recurrent:
        _rec_input = c_layers.flatten(self.conv2)
        hidden = self.create_recurrent_encoder(_rec_input.get_shape().as_list()[1], _rec_input)
    else:
        hidden = c_layers.flatten(self.conv2)

    hidden = c_layers.flatten(self.conv2)

    for j in range(num_layers):
        hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation)
    else:
        self.normalized_state = self.state_in

    if self.use_recurrent:
        self.recurrent_state = self.create_recurrent_encoder(s_size, self.normalized_state)
    else:
        self.recurrent_state = self.normalized_state

    hidden = self.recurrent_state

    hidden = self.normalized_state
    for j in range(num_layers):
        hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation)
    streams.append(hidden)
    state_in = tf.reshape(self.state_in, [-1])
    state_onehot = c_layers.one_hot_encoding(state_in, s_size)
    streams = []

    if self.use_recurrent:
        hidden = self.create_recurrent_encoder(s_size, state_onehot)
    else:
        hidden = state_onehot

    hidden = state_onehot
    for i in range(num_streams):
        for j in range(num_layers):
            hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation)
    hidden_policy = tf.concat([hidden_visual[0], hidden_state[0]], axis=1)
    hidden_value = tf.concat([hidden_visual[1], hidden_state[1]], axis=1)

    if self.use_recurrent:
        # Separate LSTMs for the policy and value streams; each receives half of
        # the flat memory vector fed through the 'recurrent_in' placeholder.
        self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in')
        _half_point = int(self.m_size / 2)
        hidden_policy, memory_policy_out = self.create_recurrent_encoder(hidden_policy,
                                                                         self.memory_in[:, :_half_point],
                                                                         name='lstm_policy')
        hidden_value, memory_value_out = self.create_recurrent_encoder(hidden_value,
                                                                       self.memory_in[:, _half_point:],
                                                                       name='lstm_value')
        self.memory_out = tf.concat([memory_policy_out, memory_value_out], axis=1, name='recurrent_out')

    self.mu = tf.layers.dense(hidden_policy, a_size, activation=None, use_bias=False,
                              kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01))
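# --- Hedged sizing note (illustrative only, not part of this file) -----------
# With the wiring above, each of the two LSTMs receives m_size / 2 memory units
# and splits them again into cell and hidden state, so the per-stream LSTM width
# is m_size / 4. The value of m_size below is an arbitrary assumption.
m_size = 256
per_stream_memory = m_size // 2      # handed to each create_recurrent_encoder call
lstm_units = per_stream_memory // 2  # BasicLSTMCell width inside each encoder
print(per_stream_memory, lstm_units)  # 128 64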
        raise Exception("No valid network configuration possible. "
                        "There are no states or observations in this brain")
    elif hidden_visual is not None and hidden_state is None:
        hidden = hidden_visual

        hidden = hidden_visual[0]
    if self.use_recurrent:
        # A single LSTM over the shared hidden stream; its memory tensors are
        # exposed under the names 'recurrent_in' and 'recurrent_out'.
        self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in')
        hidden, self.memory_out = self.create_recurrent_encoder(hidden, self.memory_in)
        self.memory_out = tf.identity(self.memory_out, name='recurrent_out')

    a_size = brain.action_space_size
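# --- Hedged inference-loop sketch (illustrative only, not part of this file) --
# Demonstrates why the memory tensors are given fixed names: at run time the
# value fetched from 'recurrent_out' is fed back through 'recurrent_in' on the
# next step. The tiny graph below is a stand-in for the real model; its op and
# all sizes are assumptions.
import numpy as np
import tensorflow as tf

m_size = 4
memory_in = tf.placeholder(tf.float32, [None, m_size], name='recurrent_in')
memory_out = tf.identity(memory_in + 1.0, name='recurrent_out')  # stand-in for the LSTM

with tf.Session() as sess:
    memory = np.zeros((1, m_size), dtype=np.float32)  # empty memory at episode start
    for step in range(3):
        memory = sess.run('recurrent_out:0', feed_dict={'recurrent_in:0': memory})
    print(memory)  # [[3. 3. 3. 3.]] -- the state was carried across the three steps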