from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.layers import linear_layer, LinearEncoder
from mlagents.trainers.torch.attention import (
    MultiHeadAttention,
    EntityEmbeddings,
    ResidualSelfAttention,
)
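

# Sketch reconstruction of the zero-mask test around the surviving assert
# below; it assumes EntityEmbeddings.get_masks (used by the tests further
# down) returns 1 for all-zero entities and 0 for the others.
def test_zero_mask_layer():
    batch_size, size = 10, 30

    def generate_input_helper(pattern):
        # Alternate blocks of random (unmasked) and all-zero (masked) entities
        _input = torch.zeros((batch_size, 0, size))
        for i, num in enumerate(pattern):
            block = torch.rand if i % 2 == 0 else torch.zeros
            _input = torch.cat([_input, block((batch_size, num, size))], dim=1)
        return _input

    masking_pattern_1 = [3, 2, 3, 4]
    masking_pattern_2 = [5, 7, 8, 2]
    input_1 = generate_input_helper(masking_pattern_1)
    input_2 = generate_input_helper(masking_pattern_2)

    masks = EntityEmbeddings.get_masks([input_1, input_2])
    assert len(masks) == 2
    masks_1, masks_2 = masks
    assert masks_1.shape == (batch_size, sum(masking_pattern_1))
    assert masks_2.shape == (batch_size, sum(masking_pattern_2))
    # Check the first entity of each block: random blocks (even i) should be
    # unmasked (0), all-zero blocks (odd i) masked (1)
    for i, index in enumerate(np.cumsum([0] + masking_pattern_2[:-1])):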
        assert masks_2[0, index] == (0 if i % 2 == 0 else 1)
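

# End-to-end training check: attention over n_k random keys must learn to
# reproduce the key closest to a random query point.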
def test_predict_closest_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    # Embedding and attention setup mirrors test_predict_minimum_training
    entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Target (assumed from the test name): the key closest to the
            # center point in Euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = torch.stack(
                [key[i, argmin[i], :] for i in range(batch_size)], dim=0
            ).detach()
        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        error = torch.mean((prediction - target) ** 2) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
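

# Training check for ResidualSelfAttention with concat_self=False: one-hot
# position tags appended to each value let the network report an index.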
def test_predict_minimum_training():
    # Of n_k random numbers, predict the index of the minimum
    np.random.seed(1336)
    torch.manual_seed(1336)
    n_k = 5
    size = n_k + 1
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(
        size, [size], embedding_size, [n_k], concat_self=False
    )
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = LinearEncoder(embedding_size, 2, n_k)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    # One-hot position tag for each of the n_k entity slots
    onehots = ModelUtils.actions_to_onehot(
        torch.arange(0, n_k, dtype=torch.float).unsqueeze(1), [n_k]
    )[0]
    onehots = onehots.expand((batch_size, -1, -1))
    losses = []
    for _ in range(400):
        num = np.random.randint(0, n_k)
        inp = torch.rand((batch_size, num + 1, 1))
        with torch.no_grad():
            # Create the target: the index of the minimum value
            argmin = torch.argmin(inp, dim=1)
            argmin = argmin.squeeze()
            argmin = argmin.detach()
        # Append the position tags to the raw values
        sliced_oh = onehots[:, : num + 1]
        inp = torch.cat([inp, sliced_oh], dim=2)
        embeddings = entity_embeddings(inp, [inp])
        masks = EntityEmbeddings.get_masks([inp])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        ce = loss(prediction, argmin)
        losses.append(ce.item())
        print(ce.item())
        optimizer.zero_grad()
        ce.backward()
        optimizer.step()
    assert np.array(losses[-20:]).mean() < 0.1