浏览代码

_

/exp-tanh
vincentpierre 4 年前
当前提交
a4b78d53
共有 1 个文件被更改,包括 8 次插入和 8 次删除
  1. 16
      ml-agents/mlagents/trainers/torch/distributions.py

16
ml-agents/mlagents/trainers/torch/distributions.py


squashed = self.transform(unsquashed_sample)
return squashed
# def _inverse_tanh(self, value):
# # capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
# capped_value = (1-EPSILON) * value
# return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)
def _inverse_tanh(self, value):
    """Apply the inverse-tanh formula to ``value`` with an EPSILON offset.

    Evaluates ``0.5 * log((1 + value) / (1 - value) + EPSILON)`` using
    ``torch.log``. ``EPSILON`` is a module-level name defined outside this
    block (presumably a small positive constant -- TODO confirm); it is
    added to the ratio before the logarithm is taken.

    :param value: input passed through the formula; it is handed to
        ``torch.log`` after arithmetic, so presumably a tensor -- verify
        against callers.
    :return: the result of the expression above.
    """
    numerator = 1 + value
    denominator = 1 - value
    # The offset is applied to the ratio, not to `value` itself.
    shifted_ratio = numerator / denominator + EPSILON
    return 0.5 * torch.log(shifted_ratio)
unsquashed = self.transform.inv(value)
unsquashed = self._inverse_tanh(value * 0.95)
unsquashed, None
unsquashed, value
)
# print("tmp decomposition", value, capped_value, unsquashed, super().log_prob(unsquashed) , self.transform.log_abs_det_jacobian(
# unsquashed, None

kernel_gain=0.2,
bias_init=Initialization.Zero,
)
torch.nn.init.constant_(self.log_sigma.bias.data, -1)
torch.zeros(1, num_outputs, requires_grad=True)
torch.ones(1, num_outputs, requires_grad=True)
torch.nn.init.constant_(self.log_sigma.data, -1)
def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
mu = self.mu(inputs)

if torch.isnan(torch.mean(log_sigma)):
print("GaussianDistribution log sigma NaN")
if self.tanh_squash:
return TanhGaussianDistInstance(mu, torch.exp(log_sigma))
else:

正在加载...
取消
保存