# Copy the critic weights into the target network (soft update with tau = 1.0).
ModelUtils.soft_update(
    self.policy.actor_critic.critic, self.target_network, 1.0
)
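
# A minimal sketch, assuming ModelUtils.soft_update performs the usual Polyak blend
# target <- tau * source + (1 - tau) * target; with tau = 1.0 (as above) this amounts
# to a hard copy of the critic into the target network. The helper name is illustrative.
def polyak_update_sketch(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    with torch.no_grad():
        for src, tgt in zip(source.parameters(), target.parameters()):
            tgt.mul_(1.0 - tau).add_(tau * src)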

# Weight applied to the entropy-scaled policy bonus in the Q-backup below.
self.alpha = 0.9

# Trainable entropy coefficient, one per action branch, stored in log space.
self._log_ent_coef = torch.nn.Parameter(
    torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),
    requires_grad=True,
)
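
# Hedged, self-contained sketch (not this trainer's code) of why the coefficient is kept
# in log space: exp() keeps it positive, and a standard SAC-style loss nudges the policy
# entropy toward a target value. All names below are illustrative.
log_ent_coef = torch.nn.Parameter(torch.log(torch.tensor([0.2])), requires_grad=True)
target_entropy = -2.0                 # e.g. minus the action dimension for continuous control
log_pi = torch.randn(32, 1)           # stand-in for log pi(a|s) over a batch
ent_coef_loss = -(log_ent_coef * (log_pi + target_entropy).detach()).mean()
ent_coef_loss.backward()              # the gradient reaches only log_ent_coef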

def sac_q_loss(
    self,
    q1_out: Dict[str, torch.Tensor],
    q2_out: Dict[str, torch.Tensor],
    target_values: Dict[str, torch.Tensor],
    dones: torch.Tensor,
    rewards: Dict[str, torch.Tensor],
    log_probs: torch.Tensor,
    loss_masks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
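    # Illustrative shapes for the arguments above (the stream names are whatever reward
    # signals the trainer uses; B is the batch size, A the number of action branches):
    #     rewards[name], target_values[name], dones, loss_masks: shape (B,)
    #     log_probs: shape (B, A)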
    q1_losses = []
    # Compute a separate Q loss for each reward stream.
    for i, name in enumerate(q1_out.keys()):
        q1_stream = q1_out[name].squeeze()
        q2_stream = q2_out[name].squeeze()
        with torch.no_grad():
            # Standard bootstrap target: reward plus the discounted target value,
            # with the bootstrap cut at terminal steps for streams that use dones.
            q_backup = rewards[name] + (
                (1.0 - self.use_dones_in_backup[name] * dones)
                * self.gammas[i]
                * target_values[name]
            )
            # The entropy coefficient is stored in log space; exponentiate before use.
            _ent_coef = torch.exp(self._log_ent_coef)
            # Policy bonus: alpha-weighted, entropy-coefficient-scaled sum of
            # pi * log pi under the current policy.
            policy_bonus = (
                self.alpha
                * _ent_coef
                * torch.sum((log_probs * log_probs.exp()), dim=1)
            )
            # Recompute the backup target with the policy bonus folded in.
            q_backup = (
                rewards[name]
                + policy_bonus
                + (
                    (1.0 - self.use_dones_in_backup[name] * dones)
                    * self.gammas[i]
                    * target_values[name]
                )
            )
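            # In equation form, the modified backup target above is, per reward stream:
            #     y = r + alpha * ent_coef * sum(pi * log pi)
            #           + (1 - done) * gamma * V_target(s')
            # Since sum(pi * log pi) acts as (an estimate of) the negative policy entropy
            # -H(pi), the extra term subtracts alpha * ent_coef * H(pi) from the standard
            # SAC bootstrap target.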
        _q1_loss = 0.5 * ModelUtils.masked_mean(
            torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
        )
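        # ModelUtils.masked_mean is assumed here to average a loss only over the entries
        # whose mask is set (e.g. ignoring padded timesteps), i.e. roughly
        #     (loss * masks).sum() / max(masks.sum(), 1)
        # (an assumption about the helper's semantics, not a quote of its implementation).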

# In the optimizer's update step, the done flags come from the batch and the Q loss is
# computed from the pieces above. sac_q_loss is called with log_probs in addition to the
# original arguments (q1_stream, q2_stream, target_values, dones, rewards, masks).
dones = ModelUtils.list_to_tensor(batch["done"])

q1_loss, q2_loss = self.sac_q_loss(
    q1_stream, q2_stream, target_values, dones, rewards, log_probs, masks
)
value_loss = self.sac_value_loss(
    log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
)
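
# Hedged, illustrative sketch (not the trainer's sac_value_loss) of the target that the
# value network is typically regressed toward in SAC: the minimum of the two Q estimates
# for a freshly sampled action, minus the entropy term. All names below are illustrative.
def sac_value_target_sketch(
    q1p: torch.Tensor, q2p: torch.Tensor, log_pi: torch.Tensor, ent_coef: torch.Tensor
) -> torch.Tensor:
    # V_target(s) = min(Q1(s, a~pi), Q2(s, a~pi)) - ent_coef * log pi(a|s)
    return torch.min(q1p, q2p) - ent_coef * log_pi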