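# Actor-side losses, computed from the log-probabilities of actions
# sampled from the current policy.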
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
entropy_loss = self.sac_entropy_loss(log_probs, masks)

total_value_loss = q1_loss + q2_loss
if self.policy.shared_critic:
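    # When the critic shares its network body with the policy, the value
    # loss is folded into the policy loss and optimized together with it.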
    policy_loss += value_loss
else:
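    # With a separate critic, the value loss is trained jointly with the
    # two Q-function losses.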
    total_value_loss += value_loss

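# Anneal the learning rate on the configured schedule and apply it to the
# policy optimizer.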
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)