|
|
|
|
|
|
num_layers=2, |
|
|
|
stream_names=None, |
|
|
|
vis_encode_type=EncoderType.SIMPLE, |
|
|
|
separate_train=False |
|
|
|
separate_train=False, |
|
|
|
): |
|
|
|
super().__init__( |
|
|
|
policy, |
|
|
|
|
|
|
shape=[None, m_size], dtype=tf.float32, name="target_recurrent_in" |
|
|
|
) |
|
|
|
self.value_memory_in = self.memory_in |
|
|
|
hidden_streams = ModelUtils.create_observation_streams( |
|
|
|
self.visual_in, |
|
|
|
self.processed_vector_in, |
|
|
|
1, |
|
|
|
self.h_size, |
|
|
|
0, |
|
|
|
vis_encode_type=vis_encode_type, |
|
|
|
stream_scopes=["critic/value/"], |
|
|
|
) |
|
|
|
# hidden_streams = ModelUtils.create_observation_streams( |
|
|
|
# self.visual_in, |
|
|
|
# self.processed_vector_in, |
|
|
|
# 1, |
|
|
|
# self.h_size, |
|
|
|
# 1, |
|
|
|
# vis_encode_type=vis_encode_type, |
|
|
|
# stream_scopes=["critic/value/"], |
|
|
|
# # reuse=True |
|
|
|
# ) |
|
|
|
hidden_critic = self._create_observation_in(self.visual_in, |
|
|
|
self.processed_vector_in, |
|
|
|
vis_encode_type) |
|
|
|
self._create_cc_critic(hidden_streams[0], hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|
|
|
self._create_cc_critic(hidden_critic, hidden_critic, TARGET_SCOPE, create_qs=False) |
|
|
|
# self._create_cc_critic(self.policy.targ_encoder, self.policy.targ_encoder, TARGET_SCOPE, create_qs=False) |
|
|
|
# self._create_cc_critic(self.policy.targ_encoder, TARGET_SCOPE, create_qs=False) |
|
|
|
else: |
|
|
|
self._create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|
|
|
|
|
|
self.value_memory_out, axis=1 |
|
|
|
) # Needed for Barracuda to work |
|
|
|
# self.critic_vars += self.get_vars("target_enc") |
|
|
|
# self.value_vars += self.get_vars("target_enc") |
|
|
|
|
|
|
|
def copy_normalization(self, mean, variance, steps): |
|
|
|
""" |
|
|
|
|
|
|
update_norm_step = tf.assign(self.normalization_steps, steps) |
|
|
|
return tf.group([update_mean, update_variance, update_norm_step]) |
|
|
|
|
|
|
|
def _create_observation_in(self, visual_in, vector_in, vis_encode_type):
    """
    Build the encoded observation for this network by delegating to the
    policy's ``_create_encoder_general``.

    The encoder is requested with this network's ``h_size``, the policy's
    ``feature_size``, the "target_enc" scope and ``reuse=True``.

    :param visual_in: Visual observation input(s), forwarded unchanged.
    :param vector_in: Vector observation input, forwarded unchanged.
    :param vis_encode_type: Type of CNN encoder.
    :return: The value produced by ``self.policy._create_encoder_general``,
        handed back unmodified. It is not stored on ``self``.
        NOTE(review): this method takes no ``share_ac_cnn`` argument and does
        not itself build a (hidden_policy, hidden_critic) tuple -- confirm the
        exact return shape against the encoder builder.
    """
    # The original noted "target_network/critic/value" as an alternative scope.
    encoder_scope = "target_enc"

    # NOTE(review): the literal 1 is presumably the encoder layer count and
    # reuse=True presumably shares already-created variables in that scope --
    # confirm both against `_create_encoder_general`'s signature.
    return self.policy._create_encoder_general(
        visual_in,
        vector_in,
        self.h_size,
        self.policy.feature_size,
        1,
        vis_encode_type,
        scope=encoder_scope,
        reuse=True,
    )
|
|
|
|
|
|
|
|
|
|
|
class SACTransferPolicyNetwork(SACTransferNetwork): |
|
|
|
""" |
|
|
|
|
|
|
self.policy.output = self.policy.output |
|
|
|
# Use the sequence length of the policy |
|
|
|
self.sequence_length_ph = self.policy.sequence_length_ph |
|
|
|
# self.hidden = hidden_critic |
|
|
|
|
|
|
|
if separate_train: |
|
|
|
hidden = tf.stop_gradient(self.policy.encoder) |
|
|
|
|
|
|
self._create_cc_critic(hidden_critic, hidden_critic, POLICY_SCOPE) |
|
|
|
# self._create_cc_critic(self.policy.encoder, self.policy.encoder, POLICY_SCOPE) |
|
|
|
# self._create_cc_critic(hidden, POLICY_SCOPE) |
|
|
|
else: |
|
|
|
self._create_dc_critic(hidden_critic, POLICY_SCOPE) |
|
|
|
|
|
|
mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out] |
|
|
|
self.memory_out = tf.concat(mem_outs, axis=1) |
|
|
|
# self.critic_vars += self.get_vars("encoding") |
|
|
|
# self.value_vars += self.get_vars("encoding") |
|
|
|
|
|
|
|
def _create_memory_ins(self, m_size): |
|
|
|
""" |
|
|
|
|
|
|
:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used |
|
|
|
once and thrown away. |
|
|
|
""" |
|
|
|
with tf.variable_scope(POLICY_SCOPE): |
|
|
|
hidden_streams = ModelUtils.create_observation_streams( |
|
|
|
self.policy.visual_in, |
|
|
|
self.policy.processed_vector_in, |
|
|
|
1, |
|
|
|
self.h_size, |
|
|
|
0, |
|
|
|
vis_encode_type=vis_encode_type, |
|
|
|
stream_scopes=["critic/value/"], |
|
|
|
) |
|
|
|
hidden_critic = hidden_streams[0] |
|
|
|
return hidden_critic |
|
|
|
# with tf.variable_scope(POLICY_SCOPE): |
|
|
|
# hidden_streams = ModelUtils.create_observation_streams( |
|
|
|
# self.policy.visual_in, |
|
|
|
# self.policy.processed_vector_in, |
|
|
|
# 1, |
|
|
|
# self.h_size, |
|
|
|
# 1, |
|
|
|
# vis_encode_type=vis_encode_type, |
|
|
|
# stream_scopes=["critic/value/"], |
|
|
|
# ) |
|
|
|
# hidden_critic = hidden_streams[0] |
|
|
|
# return hidden_critic |
|
|
|
|
|
|
|
hidden = self.policy._create_encoder_general( |
|
|
|
self.policy.visual_in, |
|
|
|
self.policy.processed_vector_in, |
|
|
|
self.h_size, |
|
|
|
self.policy.feature_size, |
|
|
|
1, |
|
|
|
vis_encode_type, |
|
|
|
scope="encoding", #"critic/value", |
|
|
|
reuse=True |
|
|
|
) |
|
|
|
return hidden |