
Refactor of attention (#4840)

Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>
Branch: MLA-1734-demo-provider
GitHub, 4 years ago
Current commit
9689449f
3 changed files, with 151 additions and 167 deletions
  1. ml-agents/mlagents/trainers/tests/torch/test_attention.py (85 changed lines)
  2. ml-agents/mlagents/trainers/torch/attention.py (231 changed lines)
  3. ml-agents/mlagents/trainers/torch/layers.py (2 changed lines)

ml-agents/mlagents/trainers/tests/torch/test_attention.py (85 changed lines)


 import numpy as np
 from mlagents.trainers.torch.layers import linear_layer
-from mlagents.trainers.torch.attention import MultiHeadAttention, SimpleTransformer
+from mlagents.trainers.torch.attention import (
+    MultiHeadAttention,
+    EntityEmbeddings,
+    ResidualSelfAttention,
+)
-    q_size, k_size, v_size, o_size, n_h, emb_size = 7, 8, 9, 10, 11, 12
+    n_h, emb_size = 4, 12
-    mha = MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size)
+    mha = MultiHeadAttention(emb_size, n_h)
-    query = torch.ones((b, n_q, q_size))
-    key = torch.ones((b, n_k, k_size))
-    value = torch.ones((b, n_k, v_size))
+    query = torch.ones((b, n_q, emb_size))
+    key = torch.ones((b, n_k, emb_size))
+    value = torch.ones((b, n_k, emb_size))
-    output, attention = mha.forward(query, key, value)
+    output, attention = mha.forward(query, key, value, n_q, n_k)
-    assert output.shape == (b, n_q, o_size)
+    assert output.shape == (b, n_q, emb_size)

-    q_size, k_size, v_size, o_size, n_h, emb_size = 7, 8, 9, 10, 11, 12
+    n_h, emb_size = 4, 12
-    mha = MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size)
+    mha = MultiHeadAttention(emb_size, n_h)
+    # create a key input with some keys all 0
+    query = torch.ones((b, n_q, emb_size))
+    key = torch.ones((b, n_k, emb_size))
+    value = torch.ones((b, n_k, emb_size))
-    # create a key input with some keys all 0
-    key = torch.ones((b, n_k, k_size))
     mask = torch.zeros((b, n_k))
     for i in range(n_k):
         if i % 3 == 0:

-    query = torch.ones((b, n_q, q_size))
-    value = torch.ones((b, n_k, v_size))
-    _, attention = mha.forward(query, key, value, mask)
+    _, attention = mha.forward(query, key, value, n_q, n_k, mask)
     for i in range(n_k):
         if i % 3 == 0:
             assert torch.sum(attention[:, :, :, i] ** 2) < epsilon
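
Reviewer note, not part of the commit: the call pattern exercised by the two tests above can be reproduced standalone. This is a minimal sketch assuming the new MultiHeadAttention(embedding_size, num_heads) constructor and the forward(query, key, value, n_q, n_k, ...) signature shown in this diff, with the mask argument left optional as the two calls above suggest; embedding_size should be divisible by the head count, since each head works on embedding_size // num_heads features.

    import torch
    from mlagents.trainers.torch.attention import MultiHeadAttention

    b, n_q, n_k, n_h, emb_size = 2, 3, 5, 4, 12
    mha = MultiHeadAttention(emb_size, n_h)
    query = torch.ones((b, n_q, emb_size))
    key = torch.ones((b, n_k, emb_size))
    value = torch.ones((b, n_k, emb_size))
    # n_q and n_k are passed explicitly so the module can avoid .size() calls
    output, attention = mha.forward(query, key, value, n_q, n_k)
    assert output.shape == (b, n_q, emb_size)
    assert attention.shape == (b, n_h, n_q, n_k)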

-def test_multi_head_attention_training():
-    np.random.seed(1336)
-    torch.manual_seed(1336)
-    size, n_h, n_k, n_q = 3, 10, 5, 1
-    embedding_size = 64
-    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
-    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
-    batch_size = 200
-    point_range = 3
-    init_error = -1.0
-    for _ in range(50):
-        query = torch.rand((batch_size, n_q, size)) * point_range * 2 - point_range
-        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
-        value = key
-        with torch.no_grad():
-            # create the target : The key closest to the query in euclidean distance
-            distance = torch.sum((query - key) ** 2, dim=2)
-            argmin = torch.argmin(distance, dim=1)
-            target = []
-            for i in range(batch_size):
-                target += [key[i, argmin[i], :]]
-            target = torch.stack(target, dim=0)
-            target = target.detach()
-        prediction, _ = mha.forward(query, key, value)
-        prediction = prediction.reshape((batch_size, size))
-        error = torch.mean((prediction - target) ** 2, dim=1)
-        error = torch.mean(error) / 2
-        if init_error == -1.0:
-            init_error = error.item()
-        else:
-            assert error.item() < init_error
-        print(error.item())
-        optimizer.zero_grad()
-        error.backward()
-        optimizer.step()
-    assert error.item() < 0.5
 def test_zero_mask_layer():
     batch_size, size = 10, 30

     input_1 = generate_input_helper(masking_pattern_1)
     input_2 = generate_input_helper(masking_pattern_2)
-    masks = SimpleTransformer.get_masks([input_1, input_2])
+    masks = EntityEmbeddings.get_masks([input_1, input_2])
     assert len(masks) == 2
     masks_1 = masks[0]
     masks_2 = masks[1]
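
Reviewer note, not part of the commit: a hedged sketch of the masking convention test_zero_mask_layer relies on, assuming EntityEmbeddings.get_masks flags an entity whose observation vector is all zeros with a 1.0 (padding) and a real entity with 0.0, as the masking test and the pooling code in this diff indicate.

    import torch
    from mlagents.trainers.torch.attention import EntityEmbeddings

    # One observation tensor of shape (batch, n_entities, entity_size); the
    # second entity is zeroed out to simulate padding.
    entities = torch.ones((2, 4, 6))
    entities[:, 1, :] = 0
    masks = EntityEmbeddings.get_masks([entities])
    assert len(masks) == 1
    assert masks[0].shape == (2, 4)
    assert masks[0][0, 1] == 1.0  # padded entity is masked out
    assert masks[0][0, 0] == 0.0  # real entity is kept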

     torch.manual_seed(1336)
     size, n_k, = 3, 5
     embedding_size = 64
-    transformer = SimpleTransformer(size, [size], embedding_size)
+    entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
+    transformer = ResidualSelfAttention(embedding_size, [n_k])
     l_layer = linear_layer(embedding_size, size)
     optimizer = torch.optim.Adam(
         list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001

     init_error = -1.0
-    for _ in range(100):
+    for _ in range(250):
         center = torch.rand((batch_size, size)) * point_range * 2 - point_range
         key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
         with torch.no_grad():

             target = torch.stack(target, dim=0)
             target = target.detach()
-        masks = SimpleTransformer.get_masks([key])
-        prediction = transformer.forward(center, [key], masks)
+        embeddings = entity_embeddings(center, [key])
+        masks = EntityEmbeddings.get_masks([key])
+        prediction = transformer.forward(embeddings, masks)
         prediction = l_layer(prediction)
         prediction = prediction.reshape((batch_size, size))
         error = torch.mean((prediction - target) ** 2, dim=1)
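
Reviewer note, not part of the commit: the updated test above spells out how the two new modules are meant to compose. A condensed usage sketch with arbitrary sizes, assuming the APIs shown in this diff:

    import torch
    from mlagents.trainers.torch.attention import EntityEmbeddings, ResidualSelfAttention
    from mlagents.trainers.torch.layers import linear_layer

    batch_size, self_size, ent_size, n_entities, emb_size = 8, 3, 3, 5, 64
    entity_embeddings = EntityEmbeddings(self_size, [ent_size], [n_entities], emb_size)
    attention = ResidualSelfAttention(emb_size, [n_entities])
    head = linear_layer(emb_size, ent_size)

    x_self = torch.rand((batch_size, self_size))
    entities = torch.rand((batch_size, n_entities, ent_size))

    # EntityEmbeddings concatenates x_self to each entity and encodes the result;
    # ResidualSelfAttention runs masked self-attention over those embeddings and
    # average-pools the un-masked entities into one vector per batch element.
    embeddings = entity_embeddings(x_self, [entities])
    masks = EntityEmbeddings.get_masks([entities])
    encoded = attention(embeddings, masks)  # (batch_size, emb_size)
    prediction = head(encoded)              # (batch_size, ent_size)

Splitting the embedding step from the attention-and-pooling step presumably lets callers reuse or swap either half independently, which the single SimpleTransformer forward did not allow.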

ml-agents/mlagents/trainers/torch/attention.py (231 changed lines)


 from mlagents.torch_utils import torch
 from typing import Tuple, Optional, List
-from mlagents.trainers.torch.layers import LinearEncoder
+from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer

 class MultiHeadAttention(torch.nn.Module):

     Takes as input to the forward method 3 tensors:
-    - query: of dimensions (batch_size, number_of_queries, key_size)
-    - key: of dimensions (batch_size, number_of_keys, key_size)
-    - value: of dimensions (batch_size, number_of_keys, value_size)
+    - query: of dimensions (batch_size, number_of_queries, embedding_size)
+    - key: of dimensions (batch_size, number_of_keys, embedding_size)
+    - value: of dimensions (batch_size, number_of_keys, embedding_size)
-    - The output: (batch_size, number_of_queries, output_size)
+    - The output: (batch_size, number_of_queries, embedding_size)

-    def __init__(
-        self,
-        query_size: int,
-        key_size: int,
-        value_size: int,
-        output_size: int,
-        num_heads: int,
-        embedding_size: int,
-    ):
+    def __init__(self, embedding_size: int, num_heads: int):
-        self.n_heads, self.embedding_size = num_heads, embedding_size
-        self.output_size = output_size
-        self.fc_q = torch.nn.Linear(query_size, self.n_heads * self.embedding_size)
-        self.fc_k = torch.nn.Linear(key_size, self.n_heads * self.embedding_size)
-        self.fc_v = torch.nn.Linear(value_size, self.n_heads * self.embedding_size)
-        # self.fc_q = LinearEncoder(query_size, 2, self.n_heads * self.embedding_size)
-        # self.fc_k = LinearEncoder(key_size,2, self.n_heads * self.embedding_size)
-        # self.fc_v = LinearEncoder(value_size,2, self.n_heads * self.embedding_size)
-        self.fc_out = torch.nn.Linear(
-            self.n_heads * self.embedding_size, self.output_size
-        )
+        self.n_heads = num_heads
+        self.head_size: int = embedding_size // self.n_heads
+        self.embedding_size: int = self.head_size * self.n_heads

     def forward(
         self,

+        n_q: int,
+        n_k: int,
-        number_of_keys: int = -1,
-        number_of_queries: int = -1,

-        # This is to avoid using .size() when possible as Barracuda does not support
-        n_q = number_of_queries if number_of_queries != -1 else query.size(1)
-        n_k = number_of_keys if number_of_keys != -1 else key.size(1)
-        query = self.fc_q(query)  # (b, n_q, h*d)
-        key = self.fc_k(key)  # (b, n_k, h*d)
-        value = self.fc_v(value)  # (b, n_k, h*d)
-        query = query.reshape(b, n_q, self.n_heads, self.embedding_size)
-        key = key.reshape(b, n_k, self.n_heads, self.embedding_size)
-        value = value.reshape(b, n_k, self.n_heads, self.embedding_size)
+        query = query.reshape(
+            b, n_q, self.n_heads, self.head_size
+        )  # (b, n_q, h, emb / h)
+        key = key.reshape(b, n_k, self.n_heads, self.head_size)  # (b, n_k, h, emb / h)
+        value = value.reshape(
+            b, n_k, self.n_heads, self.head_size
+        )  # (b, n_k, h, emb / h)
-        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb)
+        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb / h)
-        key = key.permute([0, 2, 1, 3])  # (b, h, emb, n_k)
+        key = key.permute([0, 2, 1, 3])  # (b, h, emb / h, n_k)
-        key = key.permute([0, 1, 3, 2])  # (b, h, emb, n_k)
+        key = key.permute([0, 1, 3, 2])  # (b, h, emb / h, n_k)
         qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

         att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)
-        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
-        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb)
+        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
+        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb / h)
-        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb)
+        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb / h)
-            b, n_q, self.n_heads * self.embedding_size
-        )  # (b, n_q, h*emb)
+            b, n_q, self.embedding_size
+        )  # (b, n_q, emb)
-        out = self.fc_out(value_attention)  # (b, n_q, emb)
-        return out, att
+        return value_attention, att
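
Reviewer note, not part of the commit: the refactored forward splits the embedding across heads instead of projecting each head to the full embedding size, which is where most of the removed parameters went (the q/k/v/out projections now live in ResidualSelfAttention further down). The reshape/permute pipeline is equivalent to this plain-torch illustration; masking is omitted here, and the two key permutes above are collapsed into a single permute([0, 2, 3, 1]).

    import torch

    b, n_q, n_k, n_h, emb = 2, 3, 5, 4, 12
    head_size = emb // n_h  # 3 features per head
    query = torch.ones((b, n_q, emb))
    key = torch.ones((b, n_k, emb))
    value = torch.ones((b, n_k, emb))

    q = query.reshape(b, n_q, n_h, head_size).permute([0, 2, 1, 3])  # (b, h, n_q, emb / h)
    k = key.reshape(b, n_k, n_h, head_size).permute([0, 2, 3, 1])    # (b, h, emb / h, n_k)
    v = value.reshape(b, n_k, n_h, head_size).permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
    att = torch.softmax(torch.matmul(q, k), dim=3)                   # (b, h, n_q, n_k)
    out = torch.matmul(att, v).permute([0, 2, 1, 3]).reshape(b, n_q, emb)
    assert out.shape == (b, n_q, emb)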
-class SimpleTransformer(torch.nn.Module):
+class EntityEmbeddings(torch.nn.Module):
     A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses
     multi head self attention to encode information about a "Self" and a list of
     relevant "Entities".

-    EPISLON = 1e-7

-        entities_sizes: List[int],
+        entity_sizes: List[int],
+        entity_num_max_elements: List[int],
-        output_size: Optional[int] = None,
+        concat_self: bool = True,

-        self.self_size = x_self_size
-        self.entities_sizes = entities_sizes
-        self.entities_num_max_elements: Optional[List[int]] = None
+        self.self_size: int = x_self_size
+        self.entity_sizes: List[int] = entity_sizes
+        self.entity_num_max_elements: List[int] = entity_num_max_elements
+        self.concat_self: bool = concat_self
+        # If not concatenating self, input to encoder is just entity size
+        if not concat_self:
+            self.self_size = 0

-                LinearEncoder(self.self_size + ent_size, 2, embedding_size)
-                for ent_size in self.entities_sizes
+                LinearEncoder(self.self_size + ent_size, 1, embedding_size)
+                for ent_size in self.entity_sizes

-        self.attention = MultiHeadAttention(
-            query_size=embedding_size,
-            key_size=embedding_size,
-            value_size=embedding_size,
-            output_size=embedding_size,
-            num_heads=4,
-            embedding_size=embedding_size,
-        )
-        self.residual_layer = LinearEncoder(embedding_size, 1, embedding_size)
-        if output_size is None:
-            output_size = embedding_size
-        self.x_self_residual_layer = LinearEncoder(
-            embedding_size + x_self_size, 1, output_size
-        )

-        self,
-        x_self: torch.Tensor,
-        entities: List[torch.Tensor],
-        key_masks: List[torch.Tensor],
-    ) -> torch.Tensor:
-        # Gather the maximum number of entities information
-        if self.entities_num_max_elements is None:
-            self.entities_num_max_elements = []
-            for ent in entities:
-                self.entities_num_max_elements.append(ent.shape[1])
-        # Concatenate all observations with self
-        self_and_ent: List[torch.Tensor] = []
-        for num_entities, ent in zip(self.entities_num_max_elements, entities):
-            expanded_self = x_self.reshape(-1, 1, self.self_size)
-            # .repeat(
-            #     1, num_entities, 1
-            # )
-            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
-            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
-        # Generate the tensor that will serve as query, key and value to self attention
-        qkv = torch.cat(
+        self, x_self: torch.Tensor, entities: List[torch.Tensor]
+    ) -> Tuple[torch.Tensor, int]:
+        if self.concat_self:
+            # Concatenate all observations with self
+            self_and_ent: List[torch.Tensor] = []
+            for num_entities, ent in zip(self.entity_num_max_elements, entities):
+                expanded_self = x_self.reshape(-1, 1, self.self_size)
+                expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+                self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
+        else:
+            self_and_ent = entities
+        # Encode and concatenate entites
+        encoded_entities = torch.cat(

-        mask = torch.cat(key_masks, dim=1)
-        # Feed to self attention
-        max_num_ent = sum(self.entities_num_max_elements)
-        output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
-        # Residual
-        output = self.residual_layer(output) + qkv
-        # Average Pooling
-        numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
-        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPISLON
-        output = numerator / denominator
-        # Residual between x_self and the output of the module
-        output = self.x_self_residual_layer(torch.cat([output, x_self], dim=1))
-        return output
+        return encoded_entities

     @staticmethod
     def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:

             for ent in observations
         ]
         return key_masks
+class ResidualSelfAttention(torch.nn.Module):
+    """
+    A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses
+    multi head self attention to encode information about a "Self" and a list of
+    relevant "Entities".
+    """
+
+    EPSILON = 1e-7
+
+    def __init__(
+        self,
+        embedding_size: int,
+        entity_num_max_elements: List[int],
+        num_heads: int = 4,
+    ):
+        super().__init__()
+        self.entity_num_max_elements: List[int] = entity_num_max_elements
+        self.max_num_ent = sum(entity_num_max_elements)
+        self.attention = MultiHeadAttention(
+            num_heads=num_heads, embedding_size=embedding_size
+        )
+        self.fc_q = linear_layer(
+            embedding_size,
+            embedding_size,
+            kernel_init=Initialization.Normal,
+            kernel_gain=(0.125 / embedding_size) ** 0.5,
+        )
+        self.fc_k = linear_layer(
+            embedding_size,
+            embedding_size,
+            kernel_init=Initialization.Normal,
+            kernel_gain=(0.125 / embedding_size) ** 0.5,
+        )
+        self.fc_v = linear_layer(
+            embedding_size,
+            embedding_size,
+            kernel_init=Initialization.Normal,
+            kernel_gain=(0.125 / embedding_size) ** 0.5,
+        )
+        self.fc_out = linear_layer(
+            embedding_size,
+            embedding_size,
+            kernel_init=Initialization.Normal,
+            kernel_gain=(0.125 / embedding_size) ** 0.5,
+        )
+
+    def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
+        # Gather the maximum number of entities information
+        mask = torch.cat(key_masks, dim=1)
+        # Feed to self attention
+        query = self.fc_q(inp)  # (b, n_q, emb)
+        key = self.fc_k(inp)  # (b, n_k, emb)
+        value = self.fc_v(inp)  # (b, n_k, emb)
+        output, _ = self.attention(
+            query, key, value, self.max_num_ent, self.max_num_ent, mask
+        )
+        # Residual
+        output = self.fc_out(output) + inp
+        # Average Pooling
+        numerator = torch.sum(
+            output * (1 - mask).reshape(-1, self.max_num_ent, 1), dim=1
+        )
+        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
+        output = numerator / denominator
+        # Residual between x_self and the output of the module
+        return output
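
Reviewer note, not part of the commit: an illustration of the masked average-pooling step above with concrete tensors. With mask == 1 marking padded entities, (1 - mask) zeroes their contribution and the denominator counts only the real entities; EPSILON guards against an all-padded row.

    import torch

    EPSILON = 1e-7
    b, max_num_ent, emb = 2, 3, 4
    output = torch.ones((b, max_num_ent, emb))
    mask = torch.tensor([[0.0, 0.0, 1.0], [0.0, 1.0, 1.0]])  # trailing entities padded
    numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
    denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON
    pooled = numerator / denominator  # (b, emb): mean over un-masked entities only
    assert pooled.shape == (b, emb)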

ml-agents/mlagents/trainers/torch/layers.py (2 changed lines)


     XavierGlorotUniform = 2
     KaimingHeNormal = 3  # also known as Variance scaling
     KaimingHeUniform = 4
+    Normal = 5

 _init_methods = {

     Initialization.KaimingHeNormal: torch.nn.init.kaiming_normal_,
     Initialization.KaimingHeUniform: torch.nn.init.kaiming_uniform_,
+    Initialization.Normal: torch.nn.init.normal_,
 }
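
Reviewer note, not part of the commit: the only change to layers.py is the new Normal member of the Initialization enum and its mapping to torch.nn.init.normal_, added so the scaled-normal projections in attention.py can be built. The construction below mirrors fc_q in ResidualSelfAttention; how linear_layer consumes kernel_init and kernel_gain internally is outside this diff.

    from mlagents.trainers.torch.layers import linear_layer, Initialization

    embedding_size = 64
    # Weights drawn from a normal distribution whose scale shrinks with the
    # embedding size, matching the fc_q/fc_k/fc_v/fc_out layers in this commit.
    fc_q = linear_layer(
        embedding_size,
        embedding_size,
        kernel_init=Initialization.Normal,
        kernel_gain=(0.125 / embedding_size) ** 0.5,
    )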
