浏览代码

Rename function

/develop-generalizationTraining-TrainerController
Arthur Juliani 6 年前
当前提交
b46b8708
共有 2 个文件被更改,包括 13 次插入13 次删除
  1. 10
      python/unitytrainers/models.py
  2. 16
      python/unitytrainers/ppo/models.py

10
python/unitytrainers/models.py


return update_mean, update_variance
@staticmethod
def create_continuous_observation_encoder(observation_input, h_size, activation, num_layers, scope, reuse):
def create_vector_observation_encoder(observation_input, h_size, activation, num_layers, scope, reuse):
"""
Builds a set of hidden state encoders.
:param reuse: Whether to re-use the weights within the same scope.

hidden = c_layers.flatten(conv2)
with tf.variable_scope(scope+'/'+'flat_encoding'):
hidden_flat = self.create_continuous_observation_encoder(hidden, h_size, activation,
num_layers, scope, reuse)
hidden_flat = self.create_vector_observation_encoder(hidden, h_size, activation,
num_layers, scope, reuse)
return hidden_flat
def create_observation_streams(self, num_streams, h_size, num_layers):

visual_encoders.append(encoded_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
if brain.vector_observation_space_size > 0:
hidden_state = self.create_continuous_observation_encoder(vector_observation_input,
h_size, activation_fn, num_layers,
hidden_state = self.create_vector_observation_encoder(vector_observation_input,
h_size, activation_fn, num_layers,
"main_graph_{}".format(i), False)
if hidden_state is not None and hidden_visual is not None:
final_hidden = tf.concat([hidden_visual, hidden_state], axis=1)

16
python/unitytrainers/ppo/models.py


self.next_vector_in = tf.placeholder(shape=[None, self.o_size], dtype=tf.float32,
name='next_vector_observation')
encoded_vector_obs = self.create_continuous_observation_encoder(self.vector_in,
self.curiosity_enc_size,
self.swish, 2, "vector_obs_encoder",
False)
encoded_next_vector_obs = self.create_continuous_observation_encoder(self.next_vector_in,
self.curiosity_enc_size,
self.swish, 2,
encoded_vector_obs = self.create_vector_observation_encoder(self.vector_in,
self.curiosity_enc_size,
self.swish, 2, "vector_obs_encoder",
False)
encoded_next_vector_obs = self.create_vector_observation_encoder(self.next_vector_in,
self.curiosity_enc_size,
self.swish, 2,
True)
True)
encoded_state_list.append(encoded_vector_obs)
encoded_next_state_list.append(encoded_next_vector_obs)

正在加载...
取消
保存