def sample(self):
sample = self.mean + torch.randn_like(self.mean) * self.std
return torch.clamp(sample, -3, 3) / 3
return sample / 3
def log_prob(self, value):
unscaled_val = value * 3 # Inverse of the clipping