import pytest

from mlagents.torch_utils import torch

import numpy as np

# NOTE: EntityEmbedding, ResidualSelfAttention and linear_layer are assumed to
# live under mlagents.trainers.torch; adjust the import path if your ml-agents
# layout differs.
from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
)


        assert masks_1[0, 1] == (0 if i % 2 == 0 else 1)
    for i in masking_pattern_2:
        assert masks_2[0, 1] == (0 if i % 2 == 0 else 1)


@pytest.mark.parametrize("mask_value", [0, 1])
def test_all_masking(mask_value):
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
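    # Jointly optimize the entity embedding, the self-attention block and the
    # linear read-out head with a single Adam optimizer.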
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    for _ in range(5000):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance.
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
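        # A mask value of 1 is expected to flag an entity as masked out and 0 as
        # visible, so the two parametrized cases cover "all entities masked" and
        # "no entities masked".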
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2

        optimizer.zero_grad()
        error.backward()
        optimizer.step()
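    # No accuracy assertion here: the test only checks that a fully masked or
    # fully visible input trains to completion without raising.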


def test_predict_closest_training():