|
|
|
|
|
|
# throws error on runtime broadcasting due to unknown reason. We |
|
|
|
# use this to replace torch.expand() because it is not supported in |
|
|
|
# the verified version of Barracuda (1.0.X). |
|
|
|
# log_sigma = mu * 0 + self.log_sigma |
|
|
|
log_sigma = mu * 0 + torch.clamp(self.log_sigma, -20, 2) |
|
|
|
if self.tanh_squash: |
|
|
|
return TanhGaussianDistInstance(mu, torch.exp(log_sigma)) |
|
|
|