浏览代码

Use clipped gaussian distribution for entropy calculation

/exp-clipped-gaussian-entropy
vincentpierre 4 年前
当前提交
811319c0
共有 1 个文件被更改，包括 41 次插入和 2 次删除
  1. 43
      ml-agents/mlagents/trainers/torch/distributions.py

43
ml-agents/mlagents/trainers/torch/distributions.py


return torch.exp(log_prob)
def entropy(self):
# Entropy of this Gaussian when truncated to [min_bound, max_bound] = [-1, 1],
# following the truncated-normal entropy formula
#   H = ln(sqrt(2*pi*e) * sigma * Z) + (alpha*phi(alpha) - beta*phi(beta)) / (2*Z)
# https://en.wikipedia.org/wiki/Truncated_normal_distribution
# NOTE(review): this span is diff-view residue. Indentation is stripped and the
# statement that consumes the expressions near the end (the call closed by the
# two trailing parentheses) is not visible here -- confirm against the real file
# before relying on this text.
min_bound = -1
max_bound = 1
# Truncation bounds expressed in standard-deviation units. EPSILON is defined
# outside this block; presumably a small positive constant that keeps the
# division finite when std is zero -- TODO confirm.
alpha = (min_bound - self.mean) / (self.std + EPSILON)
beta = (max_bound - self.mean) / (self.std + EPSILON)
# Z is the normal probability mass between alpha and beta, written as
# 0.5 * (erf(beta / sqrt(2)) - erf(alpha / sqrt(2))), then scaled by 0.95.
# 0.95 is a computation trick to keep the maximum sigma within bounds
Z = (
0.95
* 0.5
* (
self._erf_approximation(beta / math.sqrt(2))
- self._erf_approximation(alpha / math.sqrt(2))
)
)
# NOTE(review): the next line is the entropy of an untruncated Gaussian and
# looks like the pre-change line of the diff; presumably it is superseded by
# the two lines that follow it -- verify.
0.5 * torch.log(2 * math.pi * math.e * self.std ** 2 + EPSILON),
# Truncated-normal entropy: the log term carries the extra Z ** 2 factor and the
# second term is the boundary correction; EPSILON is added inside the log and to
# the denominator.
0.5 * torch.log(2 * math.pi * math.e * self.std ** 2 * Z ** 2 + EPSILON)
+ (alpha * self._phi(alpha) - beta * self._phi(beta)) / (2 * Z + EPSILON),
) # Use equivalent behavior to TF
)
def _erf_approximation(self, x):
# using Polynomial approximation : https://en.wikipedia.org/wiki/Error_function
t = 1 / (1 + 0.5 * torch.abs(x))
tau = t * torch.exp(
-x ** 2
- 1.26551223
+ 1.00002368 * t
+ 0.37409196 * t ** 2
+ 0.09678418 * t ** 3
- 0.18628806 * t ** 4
+ 0.27886807 * t ** 5
- 1.13520398 * t ** 6
+ 1.48851587 * t ** 7
- 0.82215223 * t ** 8
+ 0.17087277 * t ** 9
)
return (1 - tau) * torch.sign(x)
def _phi(self, x):
# standard normal distribution
return 1 / (math.sqrt(2 * math.pi)) * torch.exp(-0.5 * x ** 2)
def exported_model_output(self):
    """Return the result of ``self.sample()`` as the exported-model output.

    :return: Whatever ``self.sample()`` returns, unchanged.
    """
    sampled_output = self.sample()
    return sampled_output

正在加载...
取消
保存