|
|
|
|
|
|
list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001 |
|
|
|
) |
|
|
|
batch_size = 200 |
|
|
|
point_range = 3 |
|
|
|
init_error = -1.0 |
|
|
|
for _ in range(250): |
|
|
|
center = torch.rand((batch_size, size)) * point_range * 2 - point_range |
|
|
|
key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range |
|
|
|
for _ in range(200): |
|
|
|
center = torch.rand((batch_size, size)) |
|
|
|
key = torch.rand((batch_size, n_k, size)) |
|
|
|
with torch.no_grad(): |
|
|
|
# create the target : The key closest to the query in euclidean distance |
|
|
|
distance = torch.sum( |
|
|
|
|
|
|
prediction = prediction.reshape((batch_size, size)) |
|
|
|
error = torch.mean((prediction - target) ** 2, dim=1) |
|
|
|
error = torch.mean(error) / 2 |
|
|
|
if init_error == -1.0: |
|
|
|
init_error = error.item() |
|
|
|
else: |
|
|
|
assert error.item() < init_error |
|
|
|
assert error.item() < 0.3 |
|
|
|
assert error.item() < 0.02 |