# Adam with a small, fixed L2 penalty (weight_decay) on the optimized parameters.
self.optimizer = torch.optim.Adam(
    params,
    lr=self.trainer_settings.hyperparameters.learning_rate,
    weight_decay=1e-6,
)
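# Note: torch.optim.Adam applies weight_decay as an L2 term added to the
# gradients. A hypothetical alternative (not in the original source) is
# decoupled weight decay via torch.optim.AdamW:
#
#     self.optimizer = torch.optim.AdamW(
#         params,
#         lr=self.trainer_settings.hyperparameters.learning_rate,
#         weight_decay=1e-6,
#     )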

self.stats_name_to_update_name = {
    "Losses/Value Loss": "value_loss",

# Average the per-head value losses into a single scalar.
value_loss = torch.mean(torch.stack(value_losses))
return value_loss

def coma_regularizer_loss(
    self, values: Dict[str, torch.Tensor], baseline_values: Dict[str, torch.Tensor]
):
    reg_losses = []
    for name, head in values.items():
        reg_loss = torch.nn.functional.mse_loss(head, baseline_values[name])
        reg_losses.append(reg_loss)
    # Average across heads, mirroring the value-loss reduction above.
    return torch.mean(torch.stack(reg_losses))

# Regularizer loss reduces bias between the baseline and values. Other
# regularizers are possible here (see the sketch below).
regularizer_loss = self.coma_regularizer_loss(values, baseline_vals)
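# A hypothetical alternative regularizer (not from the original source): a
# smooth-L1 (Huber) penalty instead of MSE, less sensitive to the occasional
# large gap between a value head and its baseline:
#
#     def coma_regularizer_loss_smooth_l1(self, values, baseline_values):
#         reg_losses = [
#             torch.nn.functional.smooth_l1_loss(head, baseline_values[name])
#             for name, head in values.items()
#         ]
#         return torch.mean(torch.stack(reg_losses))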

policy_loss = self.ppo_policy_loss(
    ModelUtils.list_to_tensor(batch["advantages"]),
    ...
)
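# For reference, a clipped PPO surrogate of the kind ppo_policy_loss typically
# computes (a standalone sketch; argument names and shapes are assumptions, not
# the source's exact signature):
#
#     def clipped_surrogate_loss(advantages, log_probs, old_log_probs, masks, epsilon):
#         ratio = torch.exp(log_probs - old_log_probs)
#         unclipped = ratio * advantages
#         clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
#         # Maximize the clipped surrogate, i.e. minimize its negation.
#         return -ModelUtils.masked_mean(torch.min(unclipped, clipped), masks)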

loss = (
    policy_loss
    + 0.25 * (value_loss + baseline_loss)
    + 0.25 * regularizer_loss
    # Alternative weighting, without the regularizer term:
    # + 0.5 * (value_loss + 0.5 * baseline_loss)
    - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
)
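# The combined loss would then be applied with a standard PyTorch update step
# (a minimal sketch, assuming self.optimizer is the Adam instance constructed
# above; the source's exact update sequence is not shown in this excerpt).
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()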

# Reported training statistics.
"Losses/Value Loss": value_loss.item(),
# "Losses/Q Loss": q_loss.item(),
"Losses/Baseline Value Loss": baseline_loss.item(),
"Losses/Regularization Loss": regularizer_loss.item(),
"Policy/Learning Rate": decay_lr,
"Policy/Epsilon": decay_eps,
"Policy/Beta": decay_bet,