
Add predict minimum attention test (#4853)

Branch: MLA-1734-demo-provider
GitHub · 4 years ago
Commit: a02cf933
1 changed file with 59 additions and 3 deletions
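In brief: the commit renames test_simple_transformer_training to test_predict_closest_training and fixes its optimizer to also train the EntityEmbeddings parameters (adding a small weight decay), and it introduces test_predict_minimum_training, which checks that an EntityEmbeddings + ResidualSelfAttention stack can learn to predict the index of the minimum in a variable-length set of random numbers.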
ml-agents/mlagents/trainers/tests/torch/test_attention.py (62 lines changed: +59, −3)

@@ ... @@
 from mlagents.torch_utils import torch
 import numpy as np
-from mlagents.trainers.torch.layers import linear_layer
+from mlagents.trainers.torch.utils import ModelUtils
+from mlagents.trainers.torch.layers import linear_layer, LinearEncoder
 from mlagents.trainers.torch.attention import (
     MultiHeadAttention,
     EntityEmbeddings,
     ResidualSelfAttention,
 )
@@ ... @@
         assert masks_2[0, 1] == 0 if i % 2 == 0 else 1


-def test_simple_transformer_training():
+def test_predict_closest_training():
     np.random.seed(1336)
     torch.manual_seed(1336)
     size, n_k, = 3, 5

@@ ... @@
     l_layer = linear_layer(embedding_size, size)
     optimizer = torch.optim.Adam(
-        list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001
+        list(entity_embeddings.parameters())
+        + list(transformer.parameters())
+        + list(l_layer.parameters()),
+        lr=0.001,
+        weight_decay=1e-6,
     )
     batch_size = 200
     for _ in range(200):
@@ ... @@
         error.backward()
         optimizer.step()
     assert error.item() < 0.02
+
+
+def test_predict_minimum_training():
+    # Given a set of up to 5 numbers, predict the index of the minimum.
+    np.random.seed(1336)
+    torch.manual_seed(1336)
+    n_k = 5
+    size = n_k + 1
+    embedding_size = 64
+    entity_embeddings = EntityEmbeddings(
+        size, [size], embedding_size, [n_k], concat_self=False
+    )
+    transformer = ResidualSelfAttention(embedding_size)
+    l_layer = LinearEncoder(embedding_size, 2, n_k)
+    loss = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(
+        list(entity_embeddings.parameters())
+        + list(transformer.parameters())
+        + list(l_layer.parameters()),
+        lr=0.001,
+        weight_decay=1e-6,
+    )
+    batch_size = 200
+    onehots = ModelUtils.actions_to_onehot(torch.range(0, n_k - 1).unsqueeze(1), [n_k])[
+        0
+    ]
+    onehots = onehots.expand((batch_size, -1, -1))
+    losses = []
+    for _ in range(400):
+        num = np.random.randint(0, n_k)
+        inp = torch.rand((batch_size, num + 1, 1))
+        with torch.no_grad():
+            # create the target: the index of the minimum value
+            argmin = torch.argmin(inp, dim=1)
+            argmin = argmin.squeeze()
+            argmin = argmin.detach()
+        sliced_oh = onehots[:, : num + 1]
+        inp = torch.cat([inp, sliced_oh], dim=2)
+        embeddings = entity_embeddings(inp, [inp])
+        masks = EntityEmbeddings.get_masks([inp])
+        prediction = transformer(embeddings, masks)
+        prediction = l_layer(prediction)
+        ce = loss(prediction, argmin)
+        losses.append(ce.item())
+        print(ce.item())
+        optimizer.zero_grad()
+        ce.backward()
+        optimizer.step()
+    assert np.array(losses[-20:]).mean() < 0.1
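
For context, the new test exercises ml-agents' entity-attention pattern: each entity is a scalar value concatenated with a one-hot positional code, the set is embedded and self-attended with padding masked out, and a small head scores each candidate index. Below is a minimal standalone sketch of that forward pass, using only the classes and call signatures that appear in the diff above; the tensor shapes and the head/pooled/logits names are illustrative assumptions, not part of the commit:

    from mlagents.torch_utils import torch
    from mlagents.trainers.torch.attention import EntityEmbeddings, ResidualSelfAttention
    from mlagents.trainers.torch.layers import LinearEncoder

    n_k = 5          # maximum number of entities in a set
    size = n_k + 1   # per-entity features: 1 scalar value + n_k one-hot position bits
    embedding_size = 64

    # Same constructor calls as in test_predict_minimum_training above.
    entity_embeddings = EntityEmbeddings(
        size, [size], embedding_size, [n_k], concat_self=False
    )
    transformer = ResidualSelfAttention(embedding_size)
    head = LinearEncoder(embedding_size, 2, n_k)  # 2 hidden layers, one logit per index

    inp = torch.rand((1, n_k, size))            # (batch, entities, features)
    embeddings = entity_embeddings(inp, [inp])  # per-entity embeddings
    masks = EntityEmbeddings.get_masks([inp])   # marks all-zero (padded) entities
    pooled = transformer(embeddings, masks)     # masked self-attention over the set
    logits = head(pooled)                       # scores for each candidate index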