            vail_loss = self._beta * (kl_loss - self.mutual_information)
            # Adjust beta adaptively so the KL term tracks the mutual-information
            # target; beta is clamped to stay non-negative.
            with torch.no_grad():
                self._beta.data = torch.max(
                    self._beta + self.alpha * (kl_loss - self.mutual_information),
                    torch.tensor(0.0),
                )
            total_loss += vail_loss
stats_dict["Policy/GAIL Beta"] = self._beta.item() |
|
|
|
stats_dict["Losses/GAIL KL Loss"] = kl_loss.item() |
|
|
|
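        # WGAN-GP-style gradient penalty (Gulrajani et al., 2017): penalize the
        # discriminator when its gradient norm drifts away from 1.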
        if self.gradient_penalty_weight > 0.0:
            gradient_magnitude_loss = (
                self.gradient_penalty_weight
                * self.compute_gradient_magnitude(policy_batch, expert_batch)
            )
            stats_dict["Policy/GAIL Grad Mag Loss"] = gradient_magnitude_loss.item()
            total_loss += gradient_magnitude_loss
        return total_loss, stats_dict

    def compute_gradient_magnitude(
        self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
    ) -> torch.Tensor:
        """
        Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability,
        especially off-policy. Computes gradients w.r.t. randomly interpolated inputs.
        """
        policy_obs = self.get_state_input(policy_batch)
        expert_obs = self.get_state_input(expert_batch)
        obs_epsilon = torch.rand(policy_obs.shape)
        encoder_input = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs
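        # When the discriminator also conditions on actions, interpolate actions
        # and done flags the same way and append them to the encoder input.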
        if self._settings.use_actions:
            policy_action = self.get_action_input(policy_batch)
            expert_action = self.get_action_input(expert_batch)
            action_epsilon = torch.rand(policy_action.shape)
            policy_dones = torch.as_tensor(
                policy_batch["done"], dtype=torch.float
            ).unsqueeze(1)
            expert_dones = torch.as_tensor(
                expert_batch["done"], dtype=torch.float
            ).unsqueeze(1)
            dones_epsilon = torch.rand(policy_dones.shape)
            encoder_input = torch.cat(
                [
                    encoder_input,
                    action_epsilon * policy_action
                    + (1 - action_epsilon) * expert_action,
                    dones_epsilon * policy_dones + (1 - dones_epsilon) * expert_dones,
                ],
                dim=1,
            )
        encoder_input.requires_grad = True  # For gradient calculation
        hidden = self.encoder(encoder_input)
        if self._settings.use_vail:
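            # VAIL: sample the latent with noise so the penalized path matches
            # how the discriminator is evaluated during training.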
            use_vail_noise = True
            z_mu = self._z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
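        # Summing per-sample estimates lets a single backward pass recover each
        # sample's input gradient, since the samples are independent.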
        estimate = self._estimator(hidden).squeeze(1).sum()
        # create_graph=True makes the penalty itself differentiable w.r.t. the
        # discriminator parameters so it can be optimized.
        gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
        # Norm's gradient could be NaN at 0. Use our own safe norm instead.
        safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
        gradient_mag = torch.mean((safe_norm - 1) ** 2)
        return gradient_mag