|
|
|
|
|
|
self.value_heads[name] = value |
|
|
|
self.value = tf.reduce_mean(list(self.value_heads.values()), 0) |
|
|
|
|
|
|
|
def _create_cc_critic(self, hidden_value, scope, create_qs=True): |
|
|
|
def _create_cc_critic(self, hidden_value, scope, create_qs=True, a_layers=-1, a_features=16): |
|
|
|
""" |
|
|
|
Creates just the critic network |
|
|
|
""" |
|
|
|
|
|
|
name="external_action_in", |
|
|
|
) |
|
|
|
self.value_vars = self.get_vars(self.join_scopes(scope, "value")) |
|
|
|
|
|
|
|
# external_action_encoder = self.policy._create_action_encoder( |
|
|
|
# self.external_action_in, |
|
|
|
# self.h_size, |
|
|
|
# a_features, |
|
|
|
# a_layers, |
|
|
|
# reuse=True |
|
|
|
# ) |
|
|
|
|
|
|
|
# output_action_encoder = self.policy._create_action_encoder( |
|
|
|
# self.policy.output, |
|
|
|
# self.h_size, |
|
|
|
# a_features, |
|
|
|
# a_layers, |
|
|
|
# reuse=True |
|
|
|
# ) |
|
|
|
# hidden_q = tf.concat([hidden_value, external_action_encoder], axis=-1) |
|
|
|
# hidden_qp = tf.concat([hidden_value, output_action_encoder], axis=-1) |
|
|
|
|
|
|
|
self.num_layers, |
|
|
|
self.num_layers+2, |
|
|
|
self.h_size, |
|
|
|
self.join_scopes(scope, "q"), |
|
|
|
) |
|
|
|
|
|
|
self.num_layers, |
|
|
|
self.num_layers+2, |
|
|
|
self.h_size, |
|
|
|
self.join_scopes(scope, "q"), |
|
|
|
reuse=True, |
|
|
|
|
|
|
use_recurrent=False, |
|
|
|
encoder_layers=0, |
|
|
|
num_layers=2, |
|
|
|
action_layers=-1, |
|
|
|
action_features=16, |
|
|
|
stream_names=None, |
|
|
|
vis_encode_type=EncoderType.SIMPLE, |
|
|
|
separate_train=False, |
|
|
|
|
|
|
use_recurrent=False, |
|
|
|
encoder_layers=0, |
|
|
|
num_layers=2, |
|
|
|
action_layers=-1, |
|
|
|
action_features=16, |
|
|
|
stream_names=None, |
|
|
|
vis_encode_type=EncoderType.SIMPLE, |
|
|
|
separate_train=False, |
|
|
|
|
|
|
hidden_critic = tf.stop_gradient(hidden_critic) |
|
|
|
|
|
|
|
if self.policy.use_continuous_act: |
|
|
|
self._create_cc_critic(hidden_critic, POLICY_SCOPE) |
|
|
|
self._create_cc_critic(hidden_critic, POLICY_SCOPE, a_layers=action_layers, a_features=action_features) |
|
|
|
else: |
|
|
|
self._create_dc_critic(hidden_critic, POLICY_SCOPE) |
|
|
|
# self._create_dc_critic(hidden, POLICY_SCOPE) |
|
|
|