|
|
|
|
|
|
return torch.multinomial(self.probs, 1) |
|
|
|
|
|
|
|
def pdf(self, value): |
|
|
|
# This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]), |
|
|
|
# but torch.diag is not supported by ONNX export. |
|
|
|
idx = torch.arange(start=0, end=len(value)).unsqueeze(-1) |
|
|
|
return torch.gather( |
|
|
|
self.probs.permute(1, 0)[value.flatten().long()], -1, idx |
|
|
|