|
|
|
|
|
|
self.value_heads[name] = value |
|
|
|
self.value = tf.reduce_mean(list(self.value_heads.values()), 0) |
|
|
|
|
|
|
|
def create_cc_critic(self, hidden_value, scope, create_qs=True): |
|
|
|
def _create_cc_critic(self, hidden_value, scope, create_qs=True): |
|
|
|
""" |
|
|
|
Creates just the critic network |
|
|
|
""" |
|
|
|
|
|
|
self.q_vars = self.get_vars(self.join_scopes(scope, "q")) |
|
|
|
self.critic_vars = self.get_vars(scope) |
|
|
|
|
|
|
|
def create_dc_critic(self, hidden_value, scope, create_qs=True): |
|
|
|
def _create_dc_critic(self, hidden_value, scope, create_qs=True): |
|
|
|
""" |
|
|
|
Creates just the critic network |
|
|
|
""" |
|
|
|
|
|
|
stream_scopes=["critic/value/"], |
|
|
|
) |
|
|
|
if self.policy.use_continuous_act: |
|
|
|
self.create_cc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|
|
|
self._create_cc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|
|
|
self.create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|
|
|
self._create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|
|
|
if self.use_recurrent: |
|
|
|
self.memory_out = tf.concat( |
|
|
|
self.value_memory_out, axis=1 |
|
|
|
|
|
|
self.sequence_length_ph = self.policy.sequence_length_ph |
|
|
|
|
|
|
|
if self.policy.use_continuous_act: |
|
|
|
self.create_cc_critic(hidden_critic, POLICY_SCOPE) |
|
|
|
self._create_cc_critic(hidden_critic, POLICY_SCOPE) |
|
|
|
self.create_dc_critic(hidden_critic, POLICY_SCOPE) |
|
|
|
self._create_dc_critic(hidden_critic, POLICY_SCOPE) |
|
|
|
|
|
|
|
if self.use_recurrent: |
|
|
|
mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out] |
|
|
|