                     * self.gammas[i]
                     * target_values[name]
                 )
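
# Aside: the regression target built above is the standard one-step Q backup,
# q_backup = r + gamma * (1 - done) * V_target(s'), computed per reward
# stream (use_dones_in_backup lets a stream ignore terminations). A worked
# example with hypothetical numbers:
import torch

rewards = torch.tensor([1.0, 0.5])
dones = torch.tensor([0.0, 1.0])          # second transition is terminal
gamma = 0.99
target_values = torch.tensor([2.0, 2.0])  # V_target(s') from the target net
q_backup = rewards + (1.0 - dones) * gamma * target_values
print(q_backup)  # tensor([2.9800, 0.5000])
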
-            _q1_loss = 0.5 * torch.mean(
-                loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream)
+            _q1_loss = 0.5 * ModelUtils.masked_mean(
+                torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
             )
-            _q2_loss = 0.5 * torch.mean(
-                loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream)
+            _q2_loss = 0.5 * ModelUtils.masked_mean(
+                torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks
             )
             q1_losses.append(_q1_loss)
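
# Aside: torch.mean(loss_masks * x) zeroes the masked entries but still
# divides by the full batch size, so padded timesteps shrink the loss.
# A minimal sketch of what ModelUtils.masked_mean is assumed to compute
# (the repo's helper may differ in broadcasting details): divide by the
# number of unmasked elements instead.
import torch

def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # Average only over elements where the mask is set.
    masks = masks.float()
    return (tensor * masks).sum() / masks.sum().clamp(min=1.0)

# Toy numbers: with half the batch masked out, the two reductions disagree.
losses = torch.tensor([1.0, 3.0, 0.0, 0.0])
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
print(torch.mean(mask * losses))  # tensor(1.) -- divides by 4
print(masked_mean(losses, mask))  # tensor(2.) -- divides by 2
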
                     v_backup = min_policy_qs[name] - torch.sum(
                         _ent_coef * log_probs, dim=1
                     )
-                value_loss = 0.5 * torch.mean(
-                    loss_masks * torch.nn.functional.mse_loss(values[name], v_backup)
+                value_loss = 0.5 * ModelUtils.masked_mean(
+                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                 )
                 value_losses.append(value_loss)
         else:

                     v_backup = min_policy_qs[name] - torch.mean(
                         branched_ent_bonus, axis=0
                     )
-                value_loss = 0.5 * torch.mean(
-                    loss_masks
-                    * torch.nn.functional.mse_loss(values[name], v_backup.squeeze())
+                value_loss = 0.5 * ModelUtils.masked_mean(
+                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
+                    loss_masks,
                 )
                 value_losses.append(value_loss)
         value_loss = torch.mean(torch.stack(value_losses))
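
# Aside: both branches above implement the SAC soft state-value target,
# V(s) ~= min_i Q_i(s, a~pi) - alpha * log pi(a|s); the discrete branch
# replaces the sampled log-prob term with a per-branch expected entropy
# bonus. Continuous-action sketch with illustrative names (q1_pi, q2_pi,
# ent_coef are assumptions, not identifiers from this file):
import torch

def soft_value_target(
    q1_pi: torch.Tensor,       # Q1(s, a), a sampled from the current policy
    q2_pi: torch.Tensor,       # Q2(s, a)
    log_probs: torch.Tensor,   # log pi(a|s), one column per action dimension
    ent_coef: torch.Tensor,    # alpha
) -> torch.Tensor:
    # Clipped double-Q minus the entropy bonus; no gradients, since this
    # only serves as a regression target for the value head.
    with torch.no_grad():
        return torch.min(q1_pi, q2_pi) - torch.sum(ent_coef * log_probs, dim=1)
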
         if not discrete:
             mean_q1 = mean_q1.unsqueeze(1)
             batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
-            policy_loss = torch.mean(loss_masks * batch_policy_loss)
+            policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
         else:
             action_probs = log_probs.exp()
             branched_per_action_ent = ModelUtils.break_into_branches(

             target_current_diff = torch.squeeze(
                 target_current_diff_branched, axis=2
             )
-            entropy_loss = -torch.mean(
-                loss_masks
-                * torch.mean(self._log_ent_coef * target_current_diff, axis=1)
+            entropy_loss = -1 * ModelUtils.masked_mean(
+                torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
             )

         return entropy_loss
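
# Aside: this is SAC's automatic entropy (temperature) tuning. log(alpha)
# is learned so that the measured policy entropy tracks a fixed target;
# target_current_diff is assumed to hold the per-branch gap between target
# and current entropy. Single-objective sketch (names illustrative):
import torch

log_ent_coef = torch.zeros(1, requires_grad=True)  # log(alpha), learnable

def entropy_coef_loss(diff: torch.Tensor, loss_masks: torch.Tensor) -> torch.Tensor:
    # Same reduction as the hunk above: mean over action branches per
    # sample, then a masked mean over the batch, negated.
    masks = loss_masks.float()
    per_sample = torch.mean(log_ent_coef * diff, dim=1)
    return -1 * (per_sample * masks).sum() / masks.sum().clamp(min=1.0)
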
                 memories=next_memories,
                 sequence_length=self.policy.sequence_length,
             )
-        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32)
+        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
         use_discrete = not self.policy.use_continuous_act
         dones = ModelUtils.list_to_tensor(batch["done"])
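
# Aside: the dtype change matches how the masks are used, as keep/drop
# flags rather than numbers. torch.bool masks also support direct boolean
# indexing, which int32 masks do not (integer tensors index by position):
import torch

losses = torch.tensor([1.0, 3.0, 5.0])
masks = torch.tensor([True, True, False])
print(losses[masks].mean())  # tensor(2.) -- same as masked_mean(losses, masks)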