
fix sac shared

Branch: /develop/action-slice
Andrew Cohen, 4 years ago
Current commit: 00b891df
2 changed files with 6 additions and 2 deletions:
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (2 changes)
  2. ml-agents/mlagents/trainers/sac/optimizer_torch.py (6 changes)

ml-agents/mlagents/trainers/policy/torch_policy.py (2 changes)


             conditional_sigma=self.condition_sigma_on_obs,
             tanh_squash=tanh_squash,
         )
-        self.shared_critic = False
+        self.shared_critic = True
         # Save the m_size needed for export
         self._export_m_size = self.m_size
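
For context, a minimal, self-contained sketch (toy names, not the actual ml-agents classes) of what shared_critic = True implies: the value head branches off the same encoder as the policy, so the value loss's gradients flow through parameters the policy optimizer owns.

import torch
from torch import nn

class ToySharedActorCritic(nn.Module):
    def __init__(self, obs_size: int, act_size: int) -> None:
        super().__init__()
        self.body = nn.Linear(obs_size, 32)       # encoder shared by both heads
        self.policy_head = nn.Linear(32, act_size)
        self.value_head = nn.Linear(32, 1)        # critic reuses the shared body

    def forward(self, obs: torch.Tensor):
        hidden = torch.relu(self.body(obs))
        return self.policy_head(hidden), self.value_head(hidden)

net = ToySharedActorCritic(obs_size=8, act_size=2)
_, value = net(torch.zeros(1, 8))
value.sum().backward()
# The value gradient reached the shared body, i.e. parameters the policy
# optimizer owns -- which is why the optimizer change below is needed.
assert net.body.weight.grad is not None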

ml-agents/mlagents/trainers/sac/optimizer_torch.py (6 changes)


         policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
         entropy_loss = self.sac_entropy_loss(log_probs, masks)
-        total_value_loss = q1_loss + q2_loss + value_loss
+        if self.policy.shared_critic:
+            policy_loss += value_loss
+            total_value_loss = q1_loss + q2_loss
+        else:
+            total_value_loss = q1_loss + q2_loss + value_loss
         decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
         ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
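
The effect of this branch, as a minimal sketch (plain Python, illustrative names only, not ml-agents code): when the critic shares parameters with the policy, the value loss is folded into the policy loss so a single optimizer step updates the shared parameters; otherwise the value optimizer keeps it.

def combine_losses(policy_loss: float, q1_loss: float, q2_loss: float,
                   value_loss: float, shared_critic: bool):
    if shared_critic:
        # Value head shares parameters with the policy, so its loss
        # must be minimized by the policy optimizer's step.
        policy_loss += value_loss
        total_value_loss = q1_loss + q2_loss
    else:
        # Separate critic: the value optimizer handles the value loss.
        total_value_loss = q1_loss + q2_loss + value_loss
    return policy_loss, total_value_loss

print(combine_losses(0.5, 0.125, 0.0625, 0.25, shared_critic=True))
# -> (0.75, 0.1875): the value loss moved into the policy loss

Routing the value loss this way avoids having two optimizers step the same shared parameters, which the pre-fix code effectively did by always charging the value loss to the value optimizer.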
