
Use attention tests from master

/MLA-1734-demo-provider
Ervin Teng · 4 years ago
Current commit: c7054d76
2 files changed, with 137 additions and 95 deletions
  1. ml-agents/mlagents/trainers/tests/torch/test_attention.py (71 changes)
  2. ml-agents/mlagents/trainers/torch/attention.py (161 changes)

ml-agents/mlagents/trainers/tests/torch/test_attention.py (71 changes)


import pytest
from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch.attention import (
MultiHeadAttention,
- EntityEmbeddings,
+ EntityEmbedding,
+ get_zero_entities_mask,
)

input_1 = generate_input_helper(masking_pattern_1)
input_2 = generate_input_helper(masking_pattern_2)
- masks = EntityEmbeddings.get_masks([input_1, input_2])
+ masks = get_zero_entities_mask([input_1, input_2])
assert len(masks) == 2
masks_1 = masks[0]
masks_2 = masks[1]

assert masks_2[0, 1] == 0 if i % 2 == 0 else 1
@pytest.mark.parametrize("mask_value", [0, 1])
def test_all_masking(mask_value):
# We make sure that a mask of all zeros or all ones will not trigger an error
np.random.seed(1336)
torch.manual_seed(1336)
size, n_k, = 3, 5
embedding_size = 64
entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
entity_embeddings.add_self_embedding(size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
list(entity_embeddings.parameters())
+ list(transformer.parameters())
+ list(l_layer.parameters()),
lr=0.001,
weight_decay=1e-6,
)
batch_size = 20
for _ in range(5):
center = torch.rand((batch_size, size))
key = torch.rand((batch_size, n_k, size))
with torch.no_grad():
# create the target : The key closest to the query in euclidean distance
distance = torch.sum(
(center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
)
argmin = torch.argmin(distance, dim=1)
target = []
for i in range(batch_size):
target += [key[i, argmin[i], :]]
target = torch.stack(target, dim=0)
target = target.detach()
embeddings = entity_embeddings(center, key)
masks = [torch.ones_like(key[:, :, 0]) * mask_value]
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
error = torch.mean((prediction - target) ** 2, dim=1)
error = torch.mean(error) / 2
optimizer.zero_grad()
error.backward()
optimizer.step()
- entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
- transformer = ResidualSelfAttention(embedding_size, [n_k])
+ entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
+ entity_embeddings.add_self_embedding(size)
+ transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
list(entity_embeddings.parameters())

target = torch.stack(target, dim=0)
target = target.detach()
- embeddings = entity_embeddings(center, [key])
- masks = EntityEmbeddings.get_masks([key])
+ embeddings = entity_embeddings(center, key)
+ masks = get_zero_entities_mask([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))

n_k = 5
size = n_k + 1
embedding_size = 64
- entity_embeddings = EntityEmbeddings(
- size, [size], embedding_size, [n_k], concat_self=False
- )
+ entity_embedding = EntityEmbedding(size, n_k, embedding_size)  # no self
- list(entity_embeddings.parameters())
+ list(entity_embedding.parameters())
+ list(transformer.parameters())
+ list(l_layer.parameters()),
lr=0.001,

sliced_oh = onehots[:, : num + 1]
inp = torch.cat([inp, sliced_oh], dim=2)
- embeddings = entity_embeddings(inp, [inp])
- masks = EntityEmbeddings.get_masks([inp])
+ embeddings = entity_embedding(inp, inp)
+ masks = get_zero_entities_mask([inp])
prediction = transformer(embeddings, masks)
prediction = l_layer(prediction)
ce = loss(prediction, argmin)

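For reference, a minimal end-to-end sketch of the refactored API that these tests exercise. The constructor argument order (entity size, max number of entities, embedding size) and the call pattern follow the test code above; the concrete tensor shapes below are illustrative assumptions, not part of the diff.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)

batch_size, size, n_k, embedding_size = 4, 3, 5, 64
entity_embedding = EntityEmbedding(size, n_k, embedding_size)  # encoder for one entity type
entity_embedding.add_self_embedding(size)  # concatenate "self" to every entity (ego-centric)
attention = ResidualSelfAttention(embedding_size, n_k)

x_self = torch.rand((batch_size, size))  # the agent's own observation
entities = torch.rand((batch_size, n_k, size))  # entity observations, zero-padded to n_k slots
entities[:, -1, :] = 0.0  # pretend the last entity slot is padding
masks = get_zero_entities_mask([entities])  # 1.0 wherever an entity vector is all zeros
embeddings = entity_embedding(x_self, entities)  # (batch_size, n_k, embedding_size)
pooled = attention(embeddings, masks)  # (batch_size, embedding_size), as the tests assume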
ml-agents/mlagents/trainers/torch/attention.py (161 changes)


from mlagents.trainers.exception import UnityTrainerException
+ def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]:
+ """
+ Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
+ all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
+ layer to mask the padding observations.
+ """
+ with torch.no_grad():
+ # Generate the masking tensors for each entities tensor (mask only if all zeros)
+ key_masks: List[torch.Tensor] = [
+ (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
+ ]
+ return key_masks

class MultiHeadAttention(torch.nn.Module):
"""
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
:param embedding_size: The size of the embeddings that will be generated (should be
divisible by num_heads)
:param total_max_elements: The maximum total number of entities that can be passed to
the module
:param num_heads: The number of heads of the attention module
"""
super().__init__()
self.n_heads = num_heads
self.head_size: int = embedding_size // self.n_heads

return value_attention, att
- class EntityEmbeddings(torch.nn.Module):
+ class EntityEmbedding(torch.nn.Module):
"""
A module used to embed entities before passing them to a self-attention block.
Used in conjunction with ResidualSelfAttention to encode information about a self

def __init__(
self,
- x_self_size: int,
- entity_sizes: List[int],
- entity_num_max_elements: Optional[List[int]] = None,
- concat_self: bool = True,
+ entity_size: int,
+ entity_num_max_elements: Optional[int],
- Constructs an EntityEmbeddings module.
+ Constructs an EntityEmbedding module.
- :param entity_sizes: List of sizes for other entities. Should be of length
- equivalent to the number of entities.
- :param embedding_size: Embedding size for entity encoders.
- :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
+ :param entity_size: Size of other entities.
+ :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
- :param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric
+ :param embedding_size: Embedding size for the entity encoder.
+ :param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
- self.self_size: int = x_self_size
- self.entity_sizes: List[int] = entity_sizes
- self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
+ self.self_size: int = 0
+ self.entity_size: int = entity_size
+ self.entity_num_max_elements: int = -1
+ self.embedding_size = embedding_size
+ # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
+ self.self_ent_encoder = LinearEncoder(
+ self.entity_size,
+ 1,
+ self.embedding_size,
+ kernel_init=Initialization.Normal,
+ kernel_gain=(0.125 / self.embedding_size) ** 0.5,
+ )
- self.concat_self: bool = concat_self
- # If not concatenating self, input to encoder is just entity size
- if not concat_self:
- self.self_size = 0
- # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
- self.ent_encoders = torch.nn.ModuleList(
- [
- LinearEncoder(
- self.self_size + ent_size,
- 1,
- embedding_size,
- kernel_init=Initialization.Normal,
- kernel_gain=(0.125 / embedding_size) ** 0.5,
- )
- for ent_size in self.entity_sizes
- ]
- )
- self.embedding_norm = LayerNorm()

+ def add_self_embedding(self, size: int) -> None:
+ self.self_size = size
+ self.self_ent_encoder = LinearEncoder(
+ self.self_size + self.entity_size,
+ 1,
+ self.embedding_size,
+ kernel_init=Initialization.Normal,
+ kernel_gain=(0.125 / self.embedding_size) ** 0.5,
+ )
- def forward(
- self, x_self: torch.Tensor, entities: List[torch.Tensor]
- ) -> Tuple[torch.Tensor, int]:
- if self.concat_self:
- self_and_ent: List[torch.Tensor] = []
- for num_entities, ent in zip(self.entity_num_max_elements, entities):
- if num_entities < 0:
- if exporting_to_onnx.is_exporting():
- raise UnityTrainerException(
- "Trying to export an attention mechanism that doesn't have a set max \
- number of elements."
- )
- num_entities = ent.shape[1]
- expanded_self = x_self.reshape(-1, 1, self.self_size)
- expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
- self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
- else:
- self_and_ent = entities
- # Encode and concatenate entites
- encoded_entities = torch.cat(
- [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
- dim=1,
- )
- encoded_entities = self.embedding_norm(encoded_entities)
+ def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
+ if self.self_size > 0:
+ num_entities = self.entity_num_max_elements
+ if num_entities < 0:
+ if exporting_to_onnx.is_exporting():
+ raise UnityTrainerException(
+ "Trying to export an attention mechanism that doesn't have a set max \
+ number of elements."
+ )
+ num_entities = entities.shape[1]
+ expanded_self = x_self.reshape(-1, 1, self.self_size)
+ expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+ entities = torch.cat([expanded_self, entities], dim=2)
+ # Encode entities
+ encoded_entities = self.self_ent_encoder(entities)
- @staticmethod
- def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
- """
- Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
- all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
- layer to mask the padding observations.
- """
- with torch.no_grad():
- # Generate the masking tensors for each entities tensor (mask only if all zeros)
- key_masks: List[torch.Tensor] = [
- (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
- ]
- return key_masks
- with an EntityEmbeddings module, to apply multi head self attention to encode information
+ with an EntityEmbedding module, to apply multi head self attention to encode information
about a "Self" and a list of relevant "Entities".
"""

self,
embedding_size: int,
- entity_num_max_elements: Optional[List[int]] = None,
+ entity_num_max_elements: Optional[int] = None,
num_heads: int = 4,
):
"""

super().__init__()
self.max_num_ent: Optional[int] = None
if entity_num_max_elements is not None:
- _entity_num_max_elements = entity_num_max_elements
- self.max_num_ent = sum(_entity_num_max_elements)
+ self.max_num_ent = entity_num_max_elements
self.attention = MultiHeadAttention(
num_heads=num_heads, embedding_size=embedding_size

kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
+ self.embedding_norm = LayerNorm()

+ inp = self.embedding_norm(inp)
# Feed to self attention
query = self.fc_q(inp) # (b, n_q, emb)
key = self.fc_k(inp) # (b, n_k, emb)

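To make the masking convention concrete, here is a small sketch of get_zero_entities_mask as defined above; the shapes are illustrative assumptions.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.attention import get_zero_entities_mask

# Two observation steps, three entity slots of size four; slot 1 of the first step is zero padding.
obs = torch.ones((2, 3, 4))
obs[0, 1, :] = 0.0
mask = get_zero_entities_mask([obs])[0]
# mask has shape (2, 3) and is 1.0 only where the entity vector was all zeros
# (sum of squares below the 0.01 threshold), so mask[0, 1] == 1.0 and every other entry is 0.0.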