
added a test to make sure that a mask of all zeros or all ones would not break backpropagation

/develop/singular-embeddings
vincentpierre 4 years ago
Current commit
f5ec393b
1 changed file with 47 additions and 0 deletions
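Why a degenerate mask is worth a regression test: attention layers typically suppress masked entities by writing -inf (or a large negative number) into the attention logits before the softmax, and a row in which every key is masked then computes a softmax over nothing but -inf, producing NaN that flows through backward() into the gradients. A minimal sketch of that failure mode in plain PyTorch (illustrative only, not the code path ml-agents actually uses):

    import torch

    scores = torch.zeros((1, 5), requires_grad=True)
    mask = torch.ones((1, 5))  # every entity masked, as in mask_value=1
    masked = scores.masked_fill(mask.bool(), float("-inf"))
    attn = torch.softmax(masked, dim=-1)  # exp(-inf)=0 everywhere -> 0/0 -> NaN row
    attn.sum().backward()
    print(scores.grad)  # NaNs reach the gradient; this is what the test guards against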

ml-agents/mlagents/trainers/tests/torch/test_attention.py (+47, -0)


import pytest
from mlagents.torch_utils import torch
import numpy as np

# Imports needed by the code below (not shown in the diff excerpt):
from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention


# ... unchanged context: tail of the preceding mask test ...
    for i in masking_pattern_1:
        assert masks_1[0, i] == 0 if i % 2 == 0 else 1
    for i in masking_pattern_2:
        assert masks_2[0, i] == 0 if i % 2 == 0 else 1
@pytest.mark.parametrize("mask_value", [0, 1])
def test_all_masking(mask_value):
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    for _ in range(5000):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target: the key closest to the query in euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        # mask every entity with the same value (all zeros or all ones)
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
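# Aside (not part of the commit): the per-sample Python loop that builds
# `target` could be written as a single vectorized gather, e.g.:
#   idx = argmin.reshape((batch_size, 1, 1)).expand(batch_size, 1, size)
#   target = torch.gather(key, dim=1, index=idx).squeeze(1)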
def test_predict_closest_training():
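With the parametrization above, pytest runs the new test twice, once with an all-zeros mask and once with an all-ones mask. An invocation along these lines should select just this test:

    pytest ml-agents/mlagents/trainers/tests/torch/test_attention.py -k test_all_masking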
