|
|
|
|
|
|
logits = branch(inputs) |
|
|
|
norm_logits = self._mask_branch(logits, masks[idx]) |
|
|
|
distribution = torch.nn.functional.gumbel_softmax( |
|
|
|
norm_logits, hard=False, dim=1 |
|
|
|
norm_logits, hard=True, dim=1 |
|
|
|
) |
|
|
|
branch_distributions.append(distribution) |
|
|
|
return branch_distributions |
|
|
|