entropies = entropies.squeeze(-1)
all_probs = None
else:
all_probs = torch.cat(all_probs, dim=-1)
all_probs = torch.cat(all_probs_list, dim=-1)
return log_probs, entropies, all_probs