def sample(self):
sample = self.mean + torch.randn_like(self.mean) * self.std
return sample
return torch.clamp(sample, -3, 3) / 3
def log_prob(self, value):
var = self.std ** 2