    ) -> torch.Tensor:
        # Zero out masked logits, then subtract a large value. Technique mentioned here:
        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barracuda-friendly.
        block_mask = -1.0 * allow_mask + 1.0
        # We do -1 * tensor + constant instead of constant - tensor because it seems
        # Barracuda might swap the inputs of a "Sub" operation.
        adj_logits = logits * allow_mask - 1e8 * block_mask
        probs = torch.nn.functional.softmax(adj_logits, dim=-1)
        log_probs = torch.log(probs + EPSILON)
        return log_probs
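
    # Hypothetical worked example of the masking above (assumed toy values):
    # with logits [2.0, 1.0, 0.5] and allow_mask [1.0, 0.0, 1.0], block_mask is
    # [0.0, 1.0, 0.0] and adj_logits is [2.0, -1e8, 0.5], so softmax assigns the
    # blocked action effectively zero probability and its log-prob collapses to
    # log(EPSILON).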

    def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
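        # Split the flat, concatenated action-mask tensor into one mask per
        # action branch, sliced along the last dimension.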
        split_masks = []