import torch

from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.attention import (
    MultiHeadAttention,
    EntityEmbedding,
    ResidualSelfAttention,
)
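
# The calls below assume a small helper that builds batched entity tensors
# from a masking pattern: even-indexed runs become random (real) entities,
# odd-indexed runs become all-zero (masked) entities. This is a minimal
# sketch; the helper name comes from the excerpt, while the pattern values
# and sizes here are placeholders.
batch_size, size = 10, 30
masking_pattern_1 = [3, 2, 3, 4]
masking_pattern_2 = [5, 7, 8, 2]


def generate_input_helper(pattern):
    _input = torch.zeros((batch_size, 0, size))
    for i, run_length in enumerate(pattern):
        if i % 2 == 0:
            # Real entities: random values.
            _input = torch.cat(
                [_input, torch.rand((batch_size, run_length, size))], dim=1
            )
        else:
            # Masked entities: all zeros, which get_masks should flag.
            _input = torch.cat(
                [_input, torch.zeros((batch_size, run_length, size))], dim=1
            )
    return _input

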
input_1 = generate_input_helper(masking_pattern_1)
input_2 = generate_input_helper(masking_pattern_2)

# get_masks returns one mask per input tensor, flagging the entities that
# are all zeros so that attention can ignore them.
masks = ResidualSelfAttention.get_masks([input_1, input_2])
assert len(masks) == 2
masks_1 = masks[0]
masks_2 = masks[1]
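
# Sketch of a shape check under the placeholder helper above: each mask has
# one entry per entity, so its width is the sum of the masking pattern.
assert masks_1.shape == (batch_size, sum(masking_pattern_1))
assert masks_2.shape == (batch_size, sum(masking_pattern_2))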

# Train a small attention stack to predict, for each query, the key that is
# closest to it.
torch.manual_seed(1336)
size, n_k = 3, 5
embedding_size = 64
entity_embeddings = EntityEmbedding(size, size, n_k, embedding_size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
# Optimize all three modules jointly (the hyperparameters are assumptions).
optimizer = torch.optim.Adam(
    list(entity_embeddings.parameters())
    + list(transformer.parameters())
    + list(l_layer.parameters()),
    lr=0.001,
)
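
# A minimal sketch of the data setup this excerpt assumes: `center` is the
# query, `key` holds the n_k candidate entities per batch element, and
# `target` collects the key closest to the query. The names `center` and
# `key` appear in the excerpt; `batch_size` and the construction here are
# assumptions.
batch_size = 200
center = torch.rand((batch_size, size))
key = torch.rand((batch_size, n_k, size))
with torch.no_grad():
    distance = torch.sum(
        (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
    )
    argmin = torch.argmin(distance, dim=1)
    target = [key[i, argmin[i], :] for i in range(batch_size)]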
target = torch.stack(target, dim=0)
# Freeze the regression target so no gradients flow through it.
target = target.detach()

# Embed the query alongside its candidate keys, mask out all-zero entities,
# and run residual self-attention; a final linear layer maps the attended
# embedding back to the key size.
embeddings = entity_embeddings(center, key)
masks = ResidualSelfAttention.get_masks([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
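
# A hedged completion of the training step: regress the prediction onto the
# closest key and update all three modules. The mean-squared-error loss and
# optimizer step here are assumptions, not taken from the excerpt.
error = torch.mean((prediction - target) ** 2)
optimizer.zero_grad()
error.backward()
optimizer.step()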