|
|
|
|
|
|
|
|
|
|
return q1_heads, q2_heads, q1, q2 |
|
|
|
|
|
|
|
def _create_encoder( |
|
|
|
self, |
|
|
|
visual_in, |
|
|
|
vector_in, |
|
|
|
vis_encode_type, |
|
|
|
encoder_layers, |
|
|
|
scope, |
|
|
|
reuse=False |
|
|
|
): |
|
|
|
""" |
|
|
|
Creates the observation inputs, and a CNN if needed, |
|
|
|
:param vis_encode_type: Type of CNN encoder. |
|
|
|
:param share_ac_cnn: Whether or not to share the actor and critic CNNs. |
|
|
|
:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used |
|
|
|
once and thrown away. |
|
|
|
""" |
|
|
|
hidden = self.policy._create_encoder_general( |
|
|
|
visual_in, |
|
|
|
vector_in, |
|
|
|
self.h_size, |
|
|
|
self.policy.feature_size, |
|
|
|
encoder_layers, |
|
|
|
vis_encode_type, |
|
|
|
scope=scope, |
|
|
|
reuse=reuse |
|
|
|
) |
|
|
|
|
|
|
|
return hidden |
|
|
|
|
|
|
|
|
|
|
|
class SACTransferTargetNetwork(SACTransferNetwork): |
|
|
|
""" |
|
|
|
|
|
|
h_size=128, |
|
|
|
normalize=False, |
|
|
|
use_recurrent=False, |
|
|
|
encoder_layers=0, |
|
|
|
num_layers=2, |
|
|
|
stream_names=None, |
|
|
|
vis_encode_type=EncoderType.SIMPLE, |
|
|
|
|
|
|
shape=[None, m_size], dtype=tf.float32, name="target_recurrent_in" |
|
|
|
) |
|
|
|
self.value_memory_in = self.memory_in |
|
|
|
|
|
|
|
hidden_critic = self._create_observation_in( |
|
|
|
|
|
|
|
hidden_critic = self._create_encoder( |
|
|
|
vis_encode_type |
|
|
|
vis_encode_type, |
|
|
|
encoder_layers=encoder_layers, |
|
|
|
scope="target_enc", |
|
|
|
reuse=True |
|
|
|
if separate_train: |
|
|
|
hidden_critic = tf.stop_gradient(hidden_critic) |
|
|
|
# self._create_cc_critic(self.policy.targ_encoder, self.policy.targ_encoder, TARGET_SCOPE, create_qs=False) |
|
|
|
# self._create_cc_critic(self.policy.targ_encoder, TARGET_SCOPE, create_qs=False) |
|
|
|
# self._create_dc_critic(self.policy.targ_encoder, TARGET_SCOPE, create_qs=False) |
|
|
|
# self.critic_vars += self.get_vars("target_enc") |
|
|
|
# self.value_vars += self.get_vars("target_enc") |
|
|
|
|
|
|
|
def copy_normalization(self, mean, variance, steps): |
|
|
|
""" |
|
|
|
|
|
|
update_norm_step = tf.assign(self.normalization_steps, steps) |
|
|
|
return tf.group([update_mean, update_variance, update_norm_step]) |
|
|
|
|
|
|
|
def _create_observation_in(self, visual_in, vector_in, vis_encode_type): |
|
|
|
""" |
|
|
|
Creates the observation inputs, and a CNN if needed, |
|
|
|
:param vis_encode_type: Type of CNN encoder. |
|
|
|
:param share_ac_cnn: Whether or not to share the actor and critic CNNs. |
|
|
|
:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used |
|
|
|
once and thrown away. |
|
|
|
""" |
|
|
|
hidden = self.policy._create_encoder_general( |
|
|
|
visual_in, |
|
|
|
vector_in, |
|
|
|
self.h_size, |
|
|
|
self.policy.feature_size, |
|
|
|
1, |
|
|
|
vis_encode_type, |
|
|
|
scope="target_enc", #"target_network/critic/value", |
|
|
|
reuse=True |
|
|
|
) |
|
|
|
return hidden |
|
|
|
|
|
|
|
|
|
|
|
class SACTransferPolicyNetwork(SACTransferNetwork): |
|
|
|
""" |
|
|
|
|
|
|
h_size=128, |
|
|
|
normalize=False, |
|
|
|
use_recurrent=False, |
|
|
|
encoder_layers=0, |
|
|
|
separate_train=False |
|
|
|
separate_train=False, |
|
|
|
): |
|
|
|
super().__init__( |
|
|
|
policy, |
|
|
|
|
|
|
if self.policy.use_recurrent: |
|
|
|
self._create_memory_ins(m_size) |
|
|
|
|
|
|
|
hidden_critic = self._create_observation_in(vis_encode_type) |
|
|
|
# self.hidden = hidden_critic |
|
|
|
hidden_critic = self._create_encoder( |
|
|
|
self.visual_in, |
|
|
|
self.processed_vector_in, |
|
|
|
vis_encode_type, |
|
|
|
encoder_layers=encoder_layers, |
|
|
|
scope="encoding", |
|
|
|
reuse=True |
|
|
|
) |
|
|
|
hidden = tf.stop_gradient(self.policy.encoder) |
|
|
|
else: |
|
|
|
hidden = self.policy.encoder |
|
|
|
hidden_critic = tf.stop_gradient(hidden_critic) |
|
|
|
|
|
|
|
# self._create_cc_critic(self.policy.encoder, self.policy.encoder, POLICY_SCOPE) |
|
|
|
# self._create_cc_critic(hidden, POLICY_SCOPE) |
|
|
|
else: |
|
|
|
self._create_dc_critic(hidden_critic, POLICY_SCOPE) |
|
|
|
# self._create_dc_critic(hidden, POLICY_SCOPE) |
|
|
|
|
|
|
self.q1_memory_in = mem_ins[1] |
|
|
|
self.q2_memory_in = mem_ins[2] |
|
|
|
|
|
|
|
def _create_observation_in(self, vis_encode_type): |
|
|
|
""" |
|
|
|
Creates the observation inputs, and a CNN if needed, |
|
|
|
:param vis_encode_type: Type of CNN encoder. |
|
|
|
:param share_ac_cnn: Whether or not to share the actor and critic CNNs. |
|
|
|
:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used |
|
|
|
once and thrown away. |
|
|
|
""" |
|
|
|
hidden = self.policy._create_encoder_general( |
|
|
|
self.policy.visual_in, |
|
|
|
self.policy.processed_vector_in, |
|
|
|
self.h_size, |
|
|
|
self.policy.feature_size, |
|
|
|
1, |
|
|
|
vis_encode_type, |
|
|
|
scope="encoding", #"critic/value", |
|
|
|
reuse=True |
|
|
|
) |
|
|
|
return hidden |