|
|
|
|
|
|
return torch.multinomial(self.probs, 1) |
|
|
|
|
|
|
|
def pdf(self, value): |
|
|
|
idx = torch.range(end=len(value)).unsqueeze(-1) |
|
|
|
return torch.gather(self.probs.permute(1, 0)[value.flatten().long()], -1, idx).squeeze(-1) |
|
|
|
idx = torch.arange(start=0, end=len(value)).unsqueeze(-1) |
|
|
|
return torch.gather( |
|
|
|
self.probs.permute(1, 0)[value.flatten().long()], -1, idx |
|
|
|
).squeeze(-1) |
|
|
|
|
|
|
|
def log_prob(self, value): |
|
|
|
return torch.log(self.pdf(value)) |
|
|
|