        )
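
        # A single Adam optimizer over the collected parameters; the small
        # weight decay provides light regularization.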
        self.optimizer = torch.optim.Adam(
            params,
            lr=self.trainer_settings.hyperparameters.learning_rate,
            weight_decay=1e-6,
        )
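
        # Maps reported stat names to the corresponding keys in the update statistics.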
        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
        }
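
        # Average the accumulated value losses into a single scalar loss.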
        value_loss = torch.mean(torch.stack(value_losses))
        return value_loss

    def coma_regularizer_loss(
        self, values: Dict[str, torch.Tensor], baseline_values: Dict[str, torch.Tensor]
    ):
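        """MSE regularizer that keeps each value head close to its baseline estimate."""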
        reg_losses = []
        for name, head in values.items():
            reg_loss = torch.nn.functional.mse_loss(head, baseline_values[name])
            reg_losses.append(reg_loss)
        # Average the per-head terms, mirroring the value-loss aggregation above.
        return torch.mean(torch.stack(reg_losses))