
[Bug fix] Fix bug in GAIL gradient penalty (#4425)

Branch: MLA-1734-demo-provider
GitHub · committed 4 years ago
Commit 7b4d0865
2 files changed, 8 insertions(+), 6 deletions(-)
  1. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py (2 changes)
  2. ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (12 changes)

ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py (2 changes)


     buffer_policy = create_agent_buffer(behavior_spec, 1000)
     demo_to_buffer.return_value = None, buffer_expert
     gail_settings = GAILSettings(
-        demo_path="", learning_rate=0.05, use_vail=False, use_actions=use_actions
+        demo_path="", learning_rate=0.005, use_vail=False, use_actions=use_actions
     )
     gail_rp = create_reward_provider(
         RewardSignalType.GAIL, behavior_spec, gail_settings

ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (12 changes)


stats_dict["Policy/GAIL Beta"] = self._beta.item()
stats_dict["Losses/GAIL KL Loss"] = kl_loss.item()
if self.gradient_penalty_weight > 0.0:
total_loss += (
gradient_magnitude_loss = (
stats_dict["Policy/GAIL Grad Mag Loss"] = gradient_magnitude_loss.item()
total_loss += gradient_magnitude_loss
return total_loss, stats_dict
     def compute_gradient_magnitude(

         encoder_input = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs
         if self._settings.use_actions:
             policy_action = self.get_action_input(policy_batch)
-            expert_action = self.get_action_input(policy_batch)
+            expert_action = self.get_action_input(expert_batch)
             action_epsilon = torch.rand(policy_action.shape)
             policy_dones = torch.as_tensor(
                 policy_batch["done"], dtype=torch.float

             use_vail_noise = True
             z_mu = self._z_mu_layer(hidden)
             hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
-        hidden = self._estimator(hidden)
-        estimate = torch.mean(torch.sum(hidden, dim=1))
-        gradient = torch.autograd.grad(estimate, encoder_input)[0]
+        estimate = self._estimator(hidden).squeeze(1).sum()
+        gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
         # Norm's gradient could be NaN at 0. Use our own safe_norm
         safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
         gradient_mag = torch.mean((safe_norm - 1) ** 2)
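The other half of the fix is create_graph=True. torch.autograd.grad normally returns a gradient that is detached from the graph; with create_graph=True the gradient itself stays differentiable, so a penalty built from it can back-propagate into the discriminator's weights. A self-contained sketch of this WGAN-GP-style penalty, assuming an invented two-layer disc network and arbitrary shapes (only the safe-norm / create_graph pattern mirrors the provider):

import torch
import torch.nn as nn

EPSILON = 1e-7  # small constant, same role as the provider's self.EPSILON

# Hypothetical stand-in for the GAIL discriminator/estimator.
disc = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))

policy_obs = torch.rand(16, 8)  # invented batch of policy observations
expert_obs = torch.rand(16, 8)  # invented batch of expert observations

# Random point on the segment between each policy/expert pair.
obs_epsilon = torch.rand(16, 1)
encoder_input = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs
encoder_input.requires_grad_(True)  # needed to take d(estimate)/d(input)

estimate = disc(encoder_input).squeeze(1).sum()

# create_graph=True keeps the graph of this input gradient, so the penalty
# computed from it below is itself differentiable w.r.t. disc's weights.
gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]

# Norm's gradient can be NaN at 0, hence the epsilon-padded "safe norm".
safe_norm = (torch.sum(gradient ** 2, dim=1) + EPSILON).sqrt()
gradient_mag = torch.mean((safe_norm - 1) ** 2)

gradient_mag.backward()  # only possible because of create_graph=True above
assert disc[0].weight.grad is not None

With the old call (no create_graph), gradient carried no graph, so the penalty was effectively a constant: added into the provider's total_loss it contributed zero gradient, and the gradient penalty silently did nothing.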