浏览代码

Initial commit for multi head attention

/MLA-1734-demo-provider
vincentpierre 4 年前
当前提交
96452986
共有 2 个文件被更改,包括 145 次插入0 次删除
  1. 79
      ml-agents/mlagents/trainers/tests/torch/test_layers.py
  2. 66
      ml-agents/mlagents/trainers/torch/layers.py

79
ml-agents/mlagents/trainers/tests/torch/test_layers.py


from mlagents.torch_utils import torch
import numpy as np
from mlagents.trainers.torch.layers import (
Swish,

LSTM,
MultiHeadAttention,
)

# Hidden size should be half of memory_size
assert out.shape == (batch_size, seq_len, memory_size // 2)
assert mem.shape == (1, batch_size, memory_size)
def test_multi_head_attention_initialization():
q_size, k_size, v_size, o_size, n_h, emb_size = 7, 8, 9, 10, 11, 12
n_k, n_q, b = 13, 14, 15
mha = MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size)
query = torch.ones((b, n_q, q_size))
key = torch.ones((b, n_k, k_size))
value = torch.ones((b, n_k, v_size))
output, attention = mha.forward(query, key, value)
assert output.shape == (b, n_q, o_size)
assert attention.shape == (b, n_h, n_q, n_k)
def test_multi_head_attention_masking():
epsilon = 0.0001
q_size, k_size, v_size, o_size, n_h, emb_size = 7, 8, 9, 10, 11, 12
n_k, n_q, b = 13, 14, 15
mha = MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size)
# create a key input with some keys all 0
key = torch.ones((b, n_k, k_size))
for i in range(n_k):
if i % 3 == 0:
key[:, i, :] = 0
query = torch.ones((b, n_q, q_size))
value = torch.ones((b, n_k, v_size))
_, attention = mha.forward(query, key, value)
for i in range(n_k):
if i % 3 == 0:
assert torch.sum(attention[:, :, :, i] ** 2) < epsilon
else:
assert torch.sum(attention[:, :, :, i] ** 2) > epsilon
def test_multi_head_attention_training():
np.random.seed(1336)
torch.manual_seed(1336)
size, n_h, n_k, n_q = 3, 10, 5, 1
embedding_size = 64
mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
batch_size = 200
point_range = 3
init_error = -1.0
for _ in range(50):
query = torch.rand((batch_size, n_q, size)) * point_range * 2 - point_range
key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
value = key
with torch.no_grad():
# create the target : The key closest to the query in euclidean distance
distance = torch.sum((query - key) ** 2, dim=2)
argmin = torch.argmin(distance, dim=1)
target = []
for i in range(batch_size):
target += [key[i, argmin[i], :]]
target = torch.stack(target, dim=0)
target = target.detach()
prediction, _ = mha.forward(query, key, value)
prediction = prediction.reshape((batch_size, size))
error = torch.mean((prediction - target) ** 2, dim=1)
error = torch.mean(error) / 2
if init_error == -1.0:
init_error = error.item()
else:
assert error.item() < init_error
print(error.item())
optimizer.zero_grad()
error.backward()
optimizer.step()
assert error.item() < 0.5

66
ml-agents/mlagents/trainers/torch/layers.py


lstm_out, hidden_out = self.lstm(input_tensor, hidden)
output_mem = torch.cat(hidden_out, dim=-1)
return lstm_out, output_mem
class MultiHeadAttention(torch.nn.Module):
NEG_INF = -1e6
def __init__(
self,
query_size: int,
key_size: int,
value_size: int,
output_size: int,
num_heads: int,
embedding_size: int,
):
super().__init__()
self.n_heads, self.embedding_size = num_heads, embedding_size
self.output_size = output_size
self.fc_q = torch.nn.Linear(query_size, self.n_heads * self.embedding_size)
self.fc_k = torch.nn.Linear(key_size, self.n_heads * self.embedding_size)
self.fc_v = torch.nn.Linear(value_size, self.n_heads * self.embedding_size)
self.fc_out = torch.nn.Linear(
self.n_heads * self.embedding_size, self.output_size
)
def forward(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
b, n_q, n_k = query.size(0), query.size(1), key.size(1)
# Create a key mask : Only 1 if all values are 0 # shape = (b, n_k)
key_mask = torch.sum(key ** 2, axis=2) < 0.01
key_mask = key_mask.reshape(b, 1, 1, n_k)
query = self.fc_q(query) # (b, n_q, h*d)
key = self.fc_k(key) # (b, n_k, h*d)
value = self.fc_v(value) # (b, n_k, h*d)
query = query.reshape(b, n_q, self.n_heads, self.embedding_size)
key = key.reshape(b, n_k, self.n_heads, self.embedding_size)
value = value.reshape(b, n_k, self.n_heads, self.embedding_size)
query = query.permute([0, 2, 1, 3]) # (b, h, n_q, emb)
# The next few lines are equivalent to : key.permute([0, 2, 3, 1])
# This is a hack, ONNX will compress two permute operations and
# Barracuda will not like seeing `permute([0,2,3,1])`
key = key.permute([0, 2, 1, 3]) # (b, h, emb, n_k)
key -= 1
key += 1
key = key.permute([0, 1, 3, 2]) # (b, h, emb, n_k)
qk = torch.matmul(query, key) # (b, h, n_q, n_k)
qk = qk / (self.embedding_size ** 0.5) + key_mask * self.NEG_INF
att = torch.softmax(qk, dim=3) # (b, h, n_q, n_k)
value = value.permute([0, 2, 1, 3]) # (b, h, n_k, emb)
value_attention = torch.matmul(att, value) # (b, h, n_q, emb)
value_attention = value_attention.permute([0, 2, 1, 3]) # (b, n_q, h, emb)
value_attention = value_attention.reshape(
b, n_q, self.n_heads * self.embedding_size
) # (b, n_q, h*emb)
out = self.fc_out(value_attention) # (b, n_q, emb)
return out, att
正在加载...
取消
保存