|
|
|
|
|
|
self.value_heads[name] = value |
|
|
|
self.value = tf.reduce_mean(list(self.value_heads.values()), 0) |
|
|
|
|
|
|
|
def _create_cc_critic(self, encoder, hidden_value, scope, create_qs=True): |
|
|
|
def _create_cc_critic(self, hidden_value, scope, create_qs=True): |
|
|
|
""" |
|
|
|
Creates just the critic network |
|
|
|
""" |
|
|
|
|
|
|
encoder, |
|
|
|
hidden_value, |
|
|
|
self.num_layers, |
|
|
|
self.h_size, |
|
|
|
self.join_scopes(scope, "value"), |
|
|
|
|
|
|
shape=[None, m_size], dtype=tf.float32, name="target_recurrent_in" |
|
|
|
) |
|
|
|
self.value_memory_in = self.memory_in |
|
|
|
# hidden_streams = ModelUtils.create_observation_streams( |
|
|
|
# self.visual_in, |
|
|
|
# self.processed_vector_in, |
|
|
|
# 1, |
|
|
|
# self.h_size, |
|
|
|
# 1, |
|
|
|
# vis_encode_type=vis_encode_type, |
|
|
|
# stream_scopes=["critic/value/"], |
|
|
|
# # reuse=True |
|
|
|
# ) |
|
|
|
hidden_critic = self._create_observation_in(self.visual_in, |
|
|
|
|
|
|
|
hidden_critic = self._create_observation_in( |
|
|
|
self.visual_in, |
|
|
|
vis_encode_type) |
|
|
|
vis_encode_type |
|
|
|
) |
|
|
|
self._create_cc_critic(hidden_critic, hidden_critic, TARGET_SCOPE, create_qs=False) |
|
|
|
self._create_cc_critic(hidden_critic, TARGET_SCOPE, create_qs=False) |
|
|
|
self._create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|
|
|
self._create_dc_critic(hidden_critic, TARGET_SCOPE, create_qs=False) |
|
|
|
# self._create_dc_critic(self.policy.targ_encoder, TARGET_SCOPE, create_qs=False) |
|
|
|
if self.use_recurrent: |
|
|
|
self.memory_out = tf.concat( |
|
|
|
|
|
|
else: |
|
|
|
hidden = self.policy.encoder |
|
|
|
if self.policy.use_continuous_act: |
|
|
|
self._create_cc_critic(hidden_critic, hidden_critic, POLICY_SCOPE) |
|
|
|
self._create_cc_critic(hidden_critic, POLICY_SCOPE) |
|
|
|
# self._create_cc_critic(self.policy.encoder, self.policy.encoder, POLICY_SCOPE) |
|
|
|
# self._create_cc_critic(hidden, POLICY_SCOPE) |
|
|
|
else: |
|
|
|
|
|
|
if self.use_recurrent: |
|
|
|
mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out] |
|
|
|
self.memory_out = tf.concat(mem_outs, axis=1) |
|
|
|
# self.critic_vars += self.get_vars("encoding") |
|
|
|
# self.value_vars += self.get_vars("encoding") |
|
|
|
|
|
|
|
def _create_memory_ins(self, m_size): |
|
|
|
""" |
|
|
|
|
|
|
:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used |
|
|
|
once and thrown away. |
|
|
|
""" |
|
|
|
# with tf.variable_scope(POLICY_SCOPE): |
|
|
|
# hidden_streams = ModelUtils.create_observation_streams( |
|
|
|
# self.policy.visual_in, |
|
|
|
# self.policy.processed_vector_in, |
|
|
|
# 1, |
|
|
|
# self.h_size, |
|
|
|
# 1, |
|
|
|
# vis_encode_type=vis_encode_type, |
|
|
|
# stream_scopes=["critic/value/"], |
|
|
|
# ) |
|
|
|
# hidden_critic = hidden_streams[0] |
|
|
|
# return hidden_critic |
|
|
|
|
|
|
|
hidden = self.policy._create_encoder_general( |
|
|
|
self.policy.visual_in, |
|
|
|
self.policy.processed_vector_in, |
|
|
|