            branches.append(branch_output_layer)
        return nn.ModuleList(branches)
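
    # A sketch of how the branch list built above is typically consumed
    # (illustrative, not from the original file): each branch layer projects the
    # shared hidden state to the logits of one discrete action branch, e.g.
    #
    #     branch_logits = [branch(hidden) for branch in self.branches]
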
    def _mask_branch(
        self, logits: torch.Tensor, allow_mask: torch.Tensor
    ) -> torch.Tensor:
        # Zero out the logits of disallowed actions, then subtract a large
        # constant so softmax assigns them ~0 probability. Technique from
        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and
        # Barracuda-friendly.
        # Written as -1 * tensor + constant rather than constant - tensor,
        # since Barracuda may swap the inputs of a Sub op.
        block_mask = -1.0 * allow_mask + 1.0
        adj_logits = logits * allow_mask - 1e8 * block_mask
        probs = torch.nn.functional.softmax(adj_logits, dim=-1)
        # EPSILON (a small module-level constant) guards against log(0) for
        # fully masked-out actions.
        log_probs = torch.log(probs + EPSILON)
        return log_probs
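

# A minimal, self-contained sketch of the masking math above. EPSILON is assumed
# to be a small module-level constant (1e-7 here); the demo_* names and values
# are illustrative, not from the original file.
if __name__ == "__main__":
    import torch  # already imported at the top of the original module

    demo_logits = torch.tensor([[2.0, 0.5, -1.0]])
    demo_allow = torch.tensor([[1.0, 0.0, 1.0]])  # middle action disallowed
    demo_block = -1.0 * demo_allow + 1.0
    demo_adj = demo_logits * demo_allow - 1e8 * demo_block
    demo_probs = torch.nn.functional.softmax(demo_adj, dim=-1)
    # demo_probs ~= [[0.95, 0.00, 0.05]]: the disallowed action's probability
    # collapses to ~0, and its log-probability to log(1e-7).
    print(torch.log(demo_probs + 1e-7))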