|
|
|
|
|
|
expert_estimate, expert_mu = self.compute_estimate( |
|
|
|
expert_batch, use_vail_noise=True |
|
|
|
) |
|
|
|
stats_dict["Policy/GAIL Policy Estimate"] = ModelUtils.to_numpy( |
|
|
|
policy_estimate.mean() |
|
|
|
) |
|
|
|
stats_dict["Policy/GAIL Expert Estimate"] = ModelUtils.to_numpy( |
|
|
|
expert_estimate.mean() |
|
|
|
) |
|
|
|
stats_dict["Policy/GAIL Policy Estimate"] = policy_estimate.mean().item() |
|
|
|
stats_dict["Policy/GAIL Expert Estimate"] = expert_estimate.mean().item() |
|
|
|
stats_dict["Losses/GAIL Loss"] = ModelUtils.to_numpy(discriminator_loss) |
|
|
|
stats_dict["Losses/GAIL Loss"] = discriminator_loss.item() |
|
|
|
total_loss += discriminator_loss |
|
|
|
if self._settings.use_vail: |
|
|
|
# KL divergence loss (encourage latent representation to be normal) |
|
|
|
|
|
|
torch.tensor(0.0), |
|
|
|
) |
|
|
|
total_loss += vail_loss |
|
|
|
stats_dict["Policy/GAIL Beta"] = ModelUtils.to_numpy(self._beta) |
|
|
|
stats_dict["Losses/GAIL KL Loss"] = ModelUtils.to_numpy(kl_loss) |
|
|
|
stats_dict["Policy/GAIL Beta"] = self._beta.item() |
|
|
|
stats_dict["Losses/GAIL KL Loss"] = kl_loss.item() |
|
|
|
if self.gradient_penalty_weight > 0.0: |
|
|
|
total_loss += ( |
|
|
|
self.gradient_penalty_weight |
|
|
|