weight_decay=1e-6,
)
batch_size = 20
for _ in range(5000):
for _ in range(5):
center = torch.rand((batch_size, size))
key = torch.rand((batch_size, n_k, size))
with torch.no_grad():