|
|
|
|
|
|
return torch.exp(log_prob) |
|
|
|
|
|
|
|
def entropy(self): |
|
|
|
# The entropy is calculated using the clipped gaussian entropy: |
|
|
|
# https://en.wikipedia.org/wiki/Truncated_normal_distribution |
|
|
|
min_bound = -1 |
|
|
|
max_bound = 1 |
|
|
|
alpha = (min_bound - self.mean) / (self.std + EPSILON) |
|
|
|
beta = (max_bound - self.mean) / (self.std + EPSILON) |
|
|
|
|
|
|
|
# 0.95 is a computation trick to keep the maximum sigma within bounds |
|
|
|
Z = ( |
|
|
|
0.95 |
|
|
|
* 0.5 |
|
|
|
* ( |
|
|
|
self._erf_approximation(beta / math.sqrt(2)) |
|
|
|
- self._erf_approximation(alpha / math.sqrt(2)) |
|
|
|
) |
|
|
|
) |
|
|
|
0.5 * torch.log(2 * math.pi * math.e * self.std ** 2 + EPSILON), |
|
|
|
0.5 * torch.log(2 * math.pi * math.e * self.std ** 2 * Z ** 2 + EPSILON) |
|
|
|
+ (alpha * self._phi(alpha) - beta * self._phi(beta)) / (2 * Z + EPSILON), |
|
|
|
) # Use equivalent behavior to TF |
|
|
|
) |
|
|
|
|
|
|
|
def _erf_approximation(self, x): |
|
|
|
# using Polynomial approximation : https://en.wikipedia.org/wiki/Error_function |
|
|
|
t = 1 / (1 + 0.5 * torch.abs(x)) |
|
|
|
tau = t * torch.exp( |
|
|
|
-x ** 2 |
|
|
|
- 1.26551223 |
|
|
|
+ 1.00002368 * t |
|
|
|
+ 0.37409196 * t ** 2 |
|
|
|
+ 0.09678418 * t ** 3 |
|
|
|
- 0.18628806 * t ** 4 |
|
|
|
+ 0.27886807 * t ** 5 |
|
|
|
- 1.13520398 * t ** 6 |
|
|
|
+ 1.48851587 * t ** 7 |
|
|
|
- 0.82215223 * t ** 8 |
|
|
|
+ 0.17087277 * t ** 9 |
|
|
|
) |
|
|
|
return (1 - tau) * torch.sign(x) |
|
|
|
|
|
|
|
def _phi(self, x): |
|
|
|
# standard normal distribution |
|
|
|
return 1 / (math.sqrt(2 * math.pi)) * torch.exp(-0.5 * x ** 2) |
|
|
|
|
|
|
|
def exported_model_output(self): |
|
|
|
return self.sample() |
|
|
|