浏览代码

Normalize GAIL observations

/develop/gail-norm
Ervin Teng 4 年前
当前提交
bc746839
共有 1 个文件被更改,包括 10 次插入和 1 次删除
  1. 11
      ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py

11
ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py


def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
with torch.no_grad():
self._discriminator_network.update_normalization(mini_batch)
estimates, _ = self._discriminator_network.compute_estimate(
mini_batch, use_vail_noise=False
)

self._settings = settings
encoder_settings = NetworkSettings(
normalize=False,
normalize=True,
hidden_units=settings.encoding_size,
num_layers=2,
vis_encode_type=EncoderType.SIMPLE,

self._estimator = torch.nn.Sequential(
linear_layer(estimator_input_size, 1), torch.nn.Sigmoid()
)
def update_normalization(self, mini_batch: AgentBuffer) -> None:
    """
    Refresh the normalization of this Discriminator's encoder from a mini-batch.

    The first of the two values returned by ``self.get_state_inputs(mini_batch)``
    is handed to ``self.encoder.update_normalization``; the second value is
    not used by this method.

    :param mini_batch: Buffer from which the state inputs are extracted.
    """
    state_inputs = self.get_state_inputs(mini_batch)
    # Strict two-way unpack: a result of any other length raises, so a
    # malformed return from get_state_inputs is not silently accepted.
    vector_inputs, _unused = state_inputs
    self.encoder.update_normalization(vector_inputs)
def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""

hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
estimate = self._estimator(hidden).squeeze(1).sum()
gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
# print(torch.sum(gradient ** 2, dim=1))
# Norm's gradient could be NaN at 0. Use our own safe_norm
safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
gradient_mag = torch.mean((safe_norm - 1) ** 2)
正在加载...
取消
保存