|
|
|
|
|
|
all_probs = None |
|
|
|
else: |
|
|
|
all_probs = torch.cat(all_probs_list, dim=-1) |
|
|
|
entropy_sum = torch.sum(entropies, dim=1) |
|
|
|
return log_probs, entropy_sum, all_probs |
|
|
|
return log_probs, entropies, all_probs |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: |
|
|
|