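                    # Each zipped pair below is one discrete action branch: its per-action
                    # entropy term and the matching Q-value term. The per-branch losses are
                    # stacked along dim=1, giving one loss column per action branch.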
                    for i, (_lp, _qt) in enumerate(
                        zip(branched_per_action_ent, branched_q_term)
                    )
                ],
                dim=1,
            )
            # Average only over the unmasked (valid) timesteps; this replaces the
            # plain torch.mean(loss_masks * batch_policy_loss), which also divides
            # by the masked-out entries.
            policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
        return policy_loss

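    # Illustrative sketch (an assumption, not part of this file): ModelUtils.masked_mean
    # is taken here to average a tensor over only the entries whose mask is nonzero,
    # roughly (for 1-D inputs):
    #
    #     def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    #         return (tensor * masks).sum() / torch.clamp(masks.sum(), min=1.0)
    #
    # so padded timesteps contribute to neither the numerator nor the denominator.
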
    def sac_entropy_loss(
        self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
    ) -> torch.Tensor:
        if not discrete:
            with torch.no_grad():
                target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
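            # Automatic entropy-coefficient tuning (standard SAC): log alpha
            # (self._log_ent_coef) is adjusted so the policy's entropy is driven
            # toward self.target_entropy. target_current_diff is computed under
            # no_grad, so only the coefficient receives gradients from this loss.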
            # Same masked-mean treatment as the policy loss; this replaces the
            # earlier -torch.mean(self._log_ent_coef * loss_masks * target_current_diff).
            entropy_loss = -1 * ModelUtils.masked_mean(
                self._log_ent_coef * target_current_diff, loss_masks
            )
        else:
            with torch.no_grad():