浏览代码

fix tests

/develop/singular-embeddings
Andrew Cohen 4 年前
当前提交
6dafe05c
共有 1 个文件被更改,包括 5 次插入5 次删除
  1. 10
      ml-agents/mlagents/trainers/tests/torch/test_attention.py

10
ml-agents/mlagents/trainers/tests/torch/test_attention.py


from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.attention import (
MultiHeadAttention,
EntityEmbeddings,
EntityEmbedding,
ResidualSelfAttention,
)

input_1 = generate_input_helper(masking_pattern_1)
input_2 = generate_input_helper(masking_pattern_2)
masks = EntityEmbeddings.get_masks([input_1, input_2])
masks = ResidualSelfAttention.get_masks([input_1, input_2])
assert len(masks) == 2
masks_1 = masks[0]
masks_2 = masks[1]

torch.manual_seed(1336)
size, n_k, = 3, 5
embedding_size = 64
entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
entity_embeddings = EntityEmbedding(size, size, n_k, embedding_size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(

target = torch.stack(target, dim=0)
target = target.detach()
embeddings = entity_embeddings(center, [key])
masks = EntityEmbeddings.get_masks([key])
embeddings = entity_embeddings(center, key)
masks = ResidualSelfAttention.get_masks([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))

正在加载...
取消
保存