|
|
|
|
|
|
estimates, _ = self._discriminator_network.compute_estimate( |
|
|
|
mini_batch, use_vail_noise=False |
|
|
|
) |
|
|
|
return ( |
|
|
|
return ModelUtils.to_numpy( |
|
|
|
.detach() |
|
|
|
.cpu() |
|
|
|
.numpy() |
|
|
|
) |
|
|
|
|
|
|
|
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]: |
|
|
|
|
|
|
expert_estimate, expert_mu = self.compute_estimate( |
|
|
|
expert_batch, use_vail_noise=True |
|
|
|
) |
|
|
|
stats_dict["Policy/GAIL Policy Estimate"] = ( |
|
|
|
policy_estimate.mean().detach().cpu().numpy() |
|
|
|
stats_dict["Policy/GAIL Policy Estimate"] = ModelUtils.to_numpy( |
|
|
|
policy_estimate.mean() |
|
|
|
stats_dict["Policy/GAIL Expert Estimate"] = ( |
|
|
|
expert_estimate.mean().detach().cpu().numpy() |
|
|
|
stats_dict["Policy/GAIL Expert Estimate"] = ModelUtils.to_numpy( |
|
|
|
expert_estimate.mean() |
|
|
|
stats_dict["Losses/GAIL Loss"] = discriminator_loss.detach().cpu().numpy() |
|
|
|
stats_dict["Losses/GAIL Loss"] = ModelUtils.to_numpy(discriminator_loss) |
|
|
|
total_loss += discriminator_loss |
|
|
|
if self._settings.use_vail: |
|
|
|
# KL divergence loss (encourage latent representation to be normal) |
|
|
|
|
|
|
torch.tensor(0.0), |
|
|
|
) |
|
|
|
total_loss += vail_loss |
|
|
|
stats_dict["Policy/GAIL Beta"] = self._beta.detach().cpu().numpy() |
|
|
|
stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy() |
|
|
|
stats_dict["Policy/GAIL Beta"] = ModelUtils.to_numpy(self._beta) |
|
|
|
stats_dict["Losses/GAIL KL Loss"] = ModelUtils.to_numpy(kl_loss) |
|
|
|
if self.gradient_penalty_weight > 0.0: |
|
|
|
total_loss += ( |
|
|
|
self.gradient_penalty_weight |
|
|
|