|
|
|
|
|
|
|
|
|
|
def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mask out disallowed branches and renormalize into log-probabilities.

    Args:
        logits: Unnormalized scores whose last dimension indexes branches.
        mask: Tensor broadcastable to ``logits``, with 1 for allowed
            branches and 0 for branches to suppress.

    Returns:
        Log-probabilities of the same shape as ``logits``; masked entries
        are driven toward ``log(EPSILON)`` (a large negative value).
    """
    # Convert to probabilities, then zero out the masked branches.
    raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
    # Renormalize over the surviving branches. EPSILON in the denominator
    # guards against division by zero when the mask removes every branch.
    # (A previous unguarded division here was dead code — its result was
    # immediately overwritten — and could emit 0/0 NaN warnings; removed.)
    normalized_probs = raw_probs / (
        torch.sum(raw_probs, dim=-1, keepdim=True) + EPSILON
    )
    # Return to log-space; EPSILON keeps log() finite for zeroed entries.
    return torch.log(normalized_probs + EPSILON)
|
|
|
|
|
|
|