            branches.append(branch_output_layer)
        return nn.ModuleList(branches)

    def _mask_branch(
        self, logits: torch.Tensor, allow_mask: torch.Tensor
    ) -> torch.Tensor:
        # Zero out the logits of disallowed actions and push them to a large negative
        # value, so they get effectively zero probability after a softmax. Technique from
        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barracuda-friendly.
        block_mask = -1.0 * allow_mask + 1.0
        # We do -1 * tensor + constant instead of constant - tensor because it seems
        # Barracuda might swap the inputs of a "Sub" operation.
        logits = logits * allow_mask - 1e8 * block_mask
        return logits
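
    # Illustrative sketch (made-up values, not from the original module): with
    # logits = [2.0, 1.0, 0.5] and allow_mask = [1.0, 0.0, 1.0], block_mask is
    # [0.0, 1.0, 0.0] and the masked logits are [2.0, -1e8, 0.5]; a downstream
    # softmax then assigns effectively zero probability to the disallowed action.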

    def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
        # Split the concatenated action mask into one mask tensor per action branch.
        split_masks = []