|
|
|
|
|
|
|
|
|
|
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray: |
|
|
|
with torch.no_grad(): |
|
|
|
self._discriminator_network.update_normalization(mini_batch) |
|
|
|
estimates, _ = self._discriminator_network.compute_estimate( |
|
|
|
mini_batch, use_vail_noise=False |
|
|
|
) |
|
|
|
|
|
|
self._settings = settings |
|
|
|
|
|
|
|
encoder_settings = NetworkSettings( |
|
|
|
normalize=False, |
|
|
|
normalize=True, |
|
|
|
hidden_units=settings.encoding_size, |
|
|
|
num_layers=2, |
|
|
|
vis_encode_type=EncoderType.SIMPLE, |
|
|
|
|
|
|
self._estimator = torch.nn.Sequential( |
|
|
|
linear_layer(estimator_input_size, 1), torch.nn.Sigmoid() |
|
|
|
) |
|
|
|
|
|
|
|
def update_normalization(self, mini_batch: AgentBuffer) -> None: |
|
|
|
""" |
|
|
|
Updates the normalization of this Discriminator's encoder. |
|
|
|
""" |
|
|
|
vec_inputs, _ = self.get_state_inputs(mini_batch) |
|
|
|
self.encoder.update_normalization(vec_inputs) |
|
|
|
|
|
|
|
def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|
|
|
""" |
|
|
|
|
|
|
hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise) |
|
|
|
estimate = self._estimator(hidden).squeeze(1).sum() |
|
|
|
gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0] |
|
|
|
# print(torch.sum(gradient ** 2, dim=1)) |
|
|
|
# Norm's gradient could be NaN at 0. Use our own safe_norm |
|
|
|
safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt() |
|
|
|
gradient_mag = torch.mean((safe_norm - 1) ** 2) |