浏览代码

[add-fire] Halve Gaussian entropy (#4319)

* Halve entropy

* Fix utils test
/develop/add-fire
GitHub 4 年前
当前提交
69d29b86
共有 3 个文件被更改,包括 4 次插入和 4 次删除
  1. 4
      ml-agents/mlagents/trainers/tests/torch/test_distributions.py
  2. 2
      ml-agents/mlagents/trainers/tests/torch/test_utils.py
  3. 2
      ml-agents/mlagents/trainers/torch/distributions.py

4
ml-agents/mlagents/trainers/tests/torch/test_distributions.py


assert log_prob == pytest.approx(-0.919, abs=0.01)
for ent in dist_instance.entropy().flatten():
- # entropy of standard normal at 0
- assert ent == pytest.approx(2.83, abs=0.01)
+ # entropy of standard normal at 0, based on 1/2 + ln(sqrt(2pi)sigma)
+ assert ent == pytest.approx(1.42, abs=0.01)
def test_tanh_gaussian_dist_instance():

2
ml-agents/mlagents/trainers/tests/torch/test_utils.py


for ent in entropies.flatten():
# entropy of standard normal at 0
- assert ent == pytest.approx(2.83, abs=0.01)
+ assert ent == pytest.approx(1.42, abs=0.01)
# Test continuous
# Add two dists to the list.

2
ml-agents/mlagents/trainers/torch/distributions.py


return torch.exp(log_prob)
def entropy(self):
- return torch.log(2 * math.pi * math.e * self.std + EPSILON)
+ return 0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON)
class TanhGaussianDistInstance(GaussianDistInstance):

正在加载...
取消
保存