浏览代码

Make create critic methods private

/develop/nopreviousactions
Ervin Teng 5 年前
当前提交
c735e722
共有 2 个文件被更改,包括 10 次插入和 10 次删除
  1. 8
      ml-agents/mlagents/trainers/ppo/optimizer.py
  2. 12
      ml-agents/mlagents/trainers/sac/network.py

8
ml-agents/mlagents/trainers/ppo/optimizer.py


if num_layers < 1:
num_layers = 1
if policy.use_continuous_act:
self.create_cc_critic(h_size, num_layers, vis_encode_type)
self._create_cc_critic(h_size, num_layers, vis_encode_type)
self.create_dc_critic(h_size, num_layers, vis_encode_type)
self._create_dc_critic(h_size, num_layers, vis_encode_type)
self.learning_rate = LearningModel.create_learning_rate(
lr_schedule, lr, self.policy.global_step, int(max_step)

self.policy.initialize_or_load()
def create_cc_critic(
def _create_cc_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""

(tf.identity(self.all_old_log_probs)), axis=1, keepdims=True
)
def create_dc_critic(
def _create_dc_critic(
self, h_size: int, num_layers: int, vis_encode_type: EncoderType
) -> None:
"""

12
ml-agents/mlagents/trainers/sac/network.py


self.value_heads[name] = value
self.value = tf.reduce_mean(list(self.value_heads.values()), 0)
def create_cc_critic(self, hidden_value, scope, create_qs=True):
def _create_cc_critic(self, hidden_value, scope, create_qs=True):
"""
Creates just the critic network
"""

self.q_vars = self.get_vars(self.join_scopes(scope, "q"))
self.critic_vars = self.get_vars(scope)
def create_dc_critic(self, hidden_value, scope, create_qs=True):
def _create_dc_critic(self, hidden_value, scope, create_qs=True):
"""
Creates just the critic network
"""

stream_scopes=["critic/value/"],
)
if self.policy.use_continuous_act:
self.create_cc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False)
self._create_cc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False)
self.create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False)
self._create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False)
if self.use_recurrent:
self.memory_out = tf.concat(
self.value_memory_out, axis=1

self.sequence_length_ph = self.policy.sequence_length_ph
if self.policy.use_continuous_act:
self.create_cc_critic(hidden_critic, POLICY_SCOPE)
self._create_cc_critic(hidden_critic, POLICY_SCOPE)
self.create_dc_critic(hidden_critic, POLICY_SCOPE)
self._create_dc_critic(hidden_critic, POLICY_SCOPE)
if self.use_recurrent:
mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out]

正在加载...
取消
保存