from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.attention import (
    MultiHeadAttention,
    EntityEmbeddings,
    ResidualSelfAttention,
)
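
# These tests exercise the attention modules for variable-length inputs:
# MultiHeadAttention is the raw attention op, EntityEmbeddings encodes each
# entity and builds masks for all-zero (padding) entities, and
# ResidualSelfAttention is the full self-attention block.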
|
|
|

def test_multi_head_attention_initialization():
    n_h, emb_size = 4, 12
    n_k, n_q, b = 13, 14, 15
    mha = MultiHeadAttention(emb_size, n_h)

    query = torch.ones((b, n_q, emb_size))
    key = torch.ones((b, n_k, emb_size))
    value = torch.ones((b, n_k, emb_size))

    output, attention = mha.forward(query, key, value, n_q, n_k)

    assert output.shape == (b, n_q, emb_size)
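    # the attended output keeps one embedding-sized vector per query; the second
    # return value holds the attention weights, which the masking test below inspects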
|
|
|
def test_multi_head_attention_masking():
    epsilon = 0.0001
    n_h, emb_size = 4, 12
    n_k, n_q, b = 13, 14, 15
    mha = MultiHeadAttention(emb_size, n_h)

    query = torch.ones((b, n_q, emb_size))
    key = torch.ones((b, n_k, emb_size))
    value = torch.ones((b, n_k, emb_size))

    # create a mask that masks out every third key
    mask = torch.zeros((b, n_k))
    for i in range(n_k):
        if i % 3 == 0:
            mask[:, i] = 1

    _, attention = mha.forward(query, key, value, n_q, n_k, mask)

    # the masked keys should receive (close to) zero attention weight
    for i in range(n_k):
        if i % 3 == 0:
            assert torch.sum(attention[:, :, :, i] ** 2) < epsilon
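        else:
            # unmasked keys should still receive a non-negligible share of attention
            assert torch.sum(attention[:, :, :, i] ** 2) > epsilon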
|
|
|
|
|
|
|
|
|
|
def test_multi_head_attention_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_h, n_k, n_q = 3, 10, 5, 1
    embedding_size = 64
    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(50):
        query = torch.rand((batch_size, n_q, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        value = key
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            distance = torch.sum((query - key) ** 2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        prediction, _ = mha.forward(query, key, value)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.5
|
|
|
|
|
|
|
|
|
|
|
def test_zero_mask_layer():
    batch_size, size = 10, 30

    def generate_input_helper(pattern):
        # build a (batch, sum(pattern), size) input in which the blocks of
        # entities alternate between all-zero (padding) and random (real)
        _input = torch.zeros((batch_size, 0, size))
        for i in range(len(pattern)):
            if i % 2 == 0:
                _input = torch.cat(
                    [_input, torch.zeros((batch_size, pattern[i], size))], dim=1
                )
            else:
                _input = torch.cat(
                    [_input, torch.rand((batch_size, pattern[i], size))], dim=1
                )
        return _input

    masking_pattern_1 = [3, 2, 3, 4]
    masking_pattern_2 = [5, 7, 8, 2]
    input_1 = generate_input_helper(masking_pattern_1)
    input_2 = generate_input_helper(masking_pattern_2)

    masks = EntityEmbeddings.get_masks([input_1, input_2])
    assert len(masks) == 2
    masks_1 = masks[0]
    masks_2 = masks[1]
    assert masks_1.shape == (batch_size, sum(masking_pattern_1))
    assert masks_2.shape == (batch_size, sum(masking_pattern_2))
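    # get_masks uses the same convention as the attention mask above: 1 flags an
    # all-zero (padding) entity to be ignored, 0 marks a real entity, so with the
    # alternating pattern the first block is flagged and the next block is not
    assert torch.all(masks_1[:, 0:3] == 1)
    assert torch.all(masks_1[:, 3:5] == 0)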
|
|
|
|
|
|
def test_predict_closest_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
    transformer = ResidualSelfAttention(embedding_size, [n_k])
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001
    )
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(250):
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # the target is the key closest to the center in Euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
|
|
|