|
|
|
|
|
|
return torch.multinomial(self.probs, 1) |
|
|
|
|
|
|
|
def pdf(self, value): |
|
|
|
return torch.diag(self.probs.T[value.flatten().long()]) |
|
|
|
idx = torch.range(end=len(value)).unsqueeze(-1) |
|
|
|
return torch.gather(self.probs.permute(1, 0)[value.flatten().long()], -1, idx).squeeze(-1) |
|
|
|
|
|
|
|
def log_prob(self, value): |
|
|
|
return torch.log(self.pdf(value)) |
|
|
|