|
|
|
|
|
|
# NOTE(review): this fragment looks like a garbled / partially-merged extraction
# of a GAIL discriminator loss computation (ml-agents-style `compute_loss`).
# The enclosing `def` is not visible, at least one statement is truncated, and
# several stats-dict keys are assigned twice (an `.item()` variant immediately
# followed by a `ModelUtils.to_item(...)` variant) -- likely unresolved merge
# residue.  Comments below annotate only what is visible; missing code is
# flagged, not reconstructed.

# Run the discriminator/estimator on the expert batch.  `expert_mu` is a
# second output (presumably the VAIL latent mean, used for the KL term below
# -- TODO confirm against `compute_estimate`).
expert_estimate, expert_mu = self.compute_estimate(

expert_batch, use_vail_noise=True

)

# Log mean discriminator outputs.
# NOTE(review): each of the next two keys is assigned twice -- the later
# `ModelUtils.to_item(...)` assignment overwrites the earlier `.item()` one.
# One variant should be deleted once the intended helper is confirmed.
stats_dict["Policy/GAIL Policy Estimate"] = policy_estimate.mean().item()

stats_dict["Policy/GAIL Expert Estimate"] = expert_estimate.mean().item()

stats_dict["Policy/GAIL Policy Estimate"] = ModelUtils.to_item(policy_estimate.mean())

stats_dict["Policy/GAIL Expert Estimate"] = ModelUtils.to_item(expert_estimate.mean())

# NOTE(review): `policy_estimate` and `discriminator_loss` are never defined
# in this fragment -- their computation presumably lives in code that is
# missing from this view.  Duplicate logging of the same key again here.
stats_dict["Losses/GAIL Loss"] = discriminator_loss.item()

stats_dict["Losses/GAIL Loss"] = ModelUtils.to_item(discriminator_loss)

# Accumulate the discriminator loss into the combined loss.
total_loss += discriminator_loss

# VAIL-specific terms (variational discriminator bottleneck).
if self._settings.use_vail:

# KL divergence loss (encourage latent representation to be normal)

# NOTE(review): the statement that should define `kl_loss` / `vail_loss`
# and open the call that the next two lines close is missing -- the
# fragment is truncated here.  `torch.tensor(0.0)` is presumably the
# floor of a clamp on the adaptive beta (beta >= 0) -- TODO confirm
# against the original source.
torch.tensor(0.0),

)

# Add the VAIL loss term (defined in the missing code above).
total_loss += vail_loss

# NOTE(review): duplicated stat assignments again -- `.item()` followed by
# `ModelUtils.to_item(...)`; the second overwrites the first.
stats_dict["Policy/GAIL Beta"] = self._beta.item()

stats_dict["Losses/GAIL KL Loss"] = kl_loss.item()

stats_dict["Policy/GAIL Beta"] = ModelUtils.to_item(self._beta)

stats_dict["Losses/GAIL KL Loss"] = ModelUtils.to_item(kl_loss)

# NOTE(review): `gradient_magnitude_loss` (a gradient-penalty term, by its
# name) is also undefined in this fragment; same duplicate-logging pattern.
stats_dict["Policy/GAIL Grad Mag Loss"] = gradient_magnitude_loss.item()

stats_dict["Policy/GAIL Grad Mag Loss"] = ModelUtils.to_item(gradient_magnitude_loss)

total_loss += gradient_magnitude_loss

# Combined loss tensor plus the stats gathered above for the trainer to log.
return total_loss, stats_dict
|
|
|
|
|
|
|