return torch.multinomial(self.probs, 1)
def pdf(self, value):
return self.probs[:, value]
return torch.diag(self.probs.T[value.flatten()])
def log_prob(self, value):
return torch.log(self.pdf(value))