
Revert "Use attention tests from master"

This reverts commit 78e052be8f36381bb6857817ff0f505716be83b9.
/MLA-1734-demo-provider
Ervin Teng, 4 years ago
Current commit
da6a55a0
2 files changed, 95 insertions and 137 deletions
  1. ml-agents/mlagents/trainers/tests/torch/test_attention.py (71 changed lines)
  2. ml-agents/mlagents/trainers/torch/attention.py (161 changed lines)

ml-agents/mlagents/trainers/tests/torch/test_attention.py (71 changed lines)


import pytest
from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch.attention import (
MultiHeadAttention,
- EntityEmbedding,
+ EntityEmbeddings,
- get_zero_entities_mask,
)

input_1 = generate_input_helper(masking_pattern_1)
input_2 = generate_input_helper(masking_pattern_2)
- masks = get_zero_entities_mask([input_1, input_2])
+ masks = EntityEmbeddings.get_masks([input_1, input_2])
assert len(masks) == 2
masks_1 = masks[0]
masks_2 = masks[1]

assert masks_2[0, 1] == 0 if i % 2 == 0 else 1
@pytest.mark.parametrize("mask_value", [0, 1])
def test_all_masking(mask_value):
# We make sure that a mask of all zeros or all ones will not trigger an error
np.random.seed(1336)
torch.manual_seed(1336)
size, n_k, = 3, 5
embedding_size = 64
entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
entity_embeddings.add_self_embedding(size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
list(entity_embeddings.parameters())
+ list(transformer.parameters())
+ list(l_layer.parameters()),
lr=0.001,
weight_decay=1e-6,
)
batch_size = 20
for _ in range(5):
center = torch.rand((batch_size, size))
key = torch.rand((batch_size, n_k, size))
with torch.no_grad():
# create the target : The key closest to the query in euclidean distance
distance = torch.sum(
(center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
)
argmin = torch.argmin(distance, dim=1)
target = []
for i in range(batch_size):
target += [key[i, argmin[i], :]]
target = torch.stack(target, dim=0)
target = target.detach()
embeddings = entity_embeddings(center, key)
masks = [torch.ones_like(key[:, :, 0]) * mask_value]
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
error = torch.mean((prediction - target) ** 2, dim=1)
error = torch.mean(error) / 2
optimizer.zero_grad()
error.backward()
optimizer.step()
- entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
- entity_embeddings.add_self_embedding(size)
- transformer = ResidualSelfAttention(embedding_size, n_k)
+ entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
+ transformer = ResidualSelfAttention(embedding_size, [n_k])
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
list(entity_embeddings.parameters())

target = torch.stack(target, dim=0)
target = target.detach()
- embeddings = entity_embeddings(center, key)
- masks = get_zero_entities_mask([key])
+ embeddings = entity_embeddings(center, [key])
+ masks = EntityEmbeddings.get_masks([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))

n_k = 5
size = n_k + 1
embedding_size = 64
- entity_embedding = EntityEmbedding(size, n_k, embedding_size)  # no self
+ entity_embeddings = EntityEmbeddings(
+ size, [size], embedding_size, [n_k], concat_self=False
+ )
- list(entity_embedding.parameters())
+ list(entity_embeddings.parameters())
+ list(transformer.parameters())
+ list(l_layer.parameters()),
lr=0.001,

sliced_oh = onehots[:, : num + 1]
inp = torch.cat([inp, sliced_oh], dim=2)
- embeddings = entity_embedding(inp, inp)
- masks = get_zero_entities_mask([inp])
+ embeddings = entity_embeddings(inp, [inp])
+ masks = EntityEmbeddings.get_masks([inp])
prediction = transformer(embeddings, masks)
prediction = l_layer(prediction)
ce = loss(prediction, argmin)
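Both helpers touched by this revert, get_zero_entities_mask (removed) and EntityEmbeddings.get_masks (restored), are expected to flag entity rows that are entirely zero, which is how the tests above build their padding masks. A minimal standalone sketch of that behavior with made-up shapes; it assumes only the all-zeros rule visible in attention.py below, not either helper's exact signature:

import torch

# Hypothetical padded batch: the second entity of the first sample is all zeros.
entities = torch.ones((2, 3, 4))
entities[0, 1, :] = 0.0

# The rule both helpers implement: mask is 1 where the entity row is all zeros.
masks = [(torch.sum(ent ** 2, dim=2) < 0.01).float() for ent in [entities]]

assert masks[0].shape == (2, 3)
assert masks[0][0, 1] == 1.0  # padded entity is masked out
assert masks[0][0, 0] == 0.0  # real entity is kept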

ml-agents/mlagents/trainers/torch/attention.py (161 changed lines)


from mlagents.trainers.exception import UnityTrainerException
- def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]:
- """
- Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
- all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
- layer to mask the padding observations.
- """
- with torch.no_grad():
- # Generate the masking tensors for each entities tensor (mask only if all zeros)
- key_masks: List[torch.Tensor] = [
- (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
- ]
- return key_masks
class MultiHeadAttention(torch.nn.Module):
"""
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
:param embedding_size: The size of the embeddings that will be generated (should be
dividable by the num_heads)
:param total_max_elements: The maximum total number of entities that can be passed to
the module
:param num_heads: The number of heads of the attention module
"""
super().__init__()
self.n_heads = num_heads
self.head_size: int = embedding_size // self.n_heads

return value_attention, att
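The MultiHeadAttention docstring above pins down a shape contract: queries of shape (batch, n_q, emb), keys and values of shape (batch, n_k, emb), an attention matrix of shape (batch, heads, n_q, n_k), and an output of shape (batch, n_q, emb). As a reference for those shapes only, here is plain scaled dot-product attention; it is not the Barracuda-compatible implementation in this file, whose constructor and forward arguments are not fully shown in this diff:

import torch

batch, n_q, n_k, emb, heads = 2, 3, 5, 64, 4
head_size = emb // heads  # mirrors self.head_size = embedding_size // self.n_heads

query = torch.rand(batch, n_q, emb)
key = torch.rand(batch, n_k, emb)
value = torch.rand(batch, n_k, emb)

def split_heads(t: torch.Tensor, n: int) -> torch.Tensor:
    # (batch, n, emb) -> (batch, heads, n, head_size)
    return t.reshape(batch, n, heads, head_size).transpose(1, 2)

q, k, v = split_heads(query, n_q), split_heads(key, n_k), split_heads(value, n_k)

# Attention matrix: (batch, num_heads, number_of_queries, number_of_keys)
att = torch.softmax(q @ k.transpose(2, 3) / head_size ** 0.5, dim=-1)
assert att.shape == (batch, heads, n_q, n_k)

# Output: (batch, number_of_queries, embedding_size)
out = (att @ v).transpose(1, 2).reshape(batch, n_q, emb)
assert out.shape == (batch, n_q, emb)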
- class EntityEmbedding(torch.nn.Module):
+ class EntityEmbeddings(torch.nn.Module):
"""
A module used to embed entities before passing them to a self-attention block.
Used in conjunction with ResidualSelfAttention to encode information about a self

def __init__(
self,
- entity_size: int,
- entity_num_max_elements: Optional[int],
+ x_self_size: int,
+ entity_sizes: List[int],
+ entity_num_max_elements: Optional[List[int]] = None,
+ concat_self: bool = True,
- Constructs an EntityEmbedding module.
+ Constructs an EntityEmbeddings module.
- :param entity_size: Size of other entities.
- :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
+ :param entity_sizes: List of sizes for other entities. Should be of length
+ equivalent to the number of entities.
+ :param embedding_size: Embedding size for entity encoders.
+ :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
- :param embedding_size: Embedding size for the entity encoder.
- :param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
+ :param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric
- self.self_size: int = 0
- self.entity_size: int = entity_size
- self.entity_num_max_elements: int = -1
+ self.self_size: int = x_self_size
+ self.entity_sizes: List[int] = entity_sizes
+ self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
self.embedding_size = embedding_size
+ self.concat_self: bool = concat_self
+ # If not concatenating self, input to encoder is just entity size
+ if not concat_self:
+ self.self_size = 0
- self.self_ent_encoder = LinearEncoder(
- self.entity_size,
- 1,
- self.embedding_size,
- kernel_init=Initialization.Normal,
- kernel_gain=(0.125 / self.embedding_size) ** 0.5,
- )
+ self.ent_encoders = torch.nn.ModuleList(
+ [
+ LinearEncoder(
+ self.self_size + ent_size,
+ 1,
+ embedding_size,
+ kernel_init=Initialization.Normal,
+ kernel_gain=(0.125 / embedding_size) ** 0.5,
+ )
+ for ent_size in self.entity_sizes
+ ]
+ )
- def add_self_embedding(self, size: int) -> None:
- self.self_size = size
- self.self_ent_encoder = LinearEncoder(
- self.self_size + self.entity_size,
- 1,
- self.embedding_size,
- kernel_init=Initialization.Normal,
- kernel_gain=(0.125 / self.embedding_size) ** 0.5,
- )
self.embedding_norm = LayerNorm()
- def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
- if self.self_size > 0:
- num_entities = self.entity_num_max_elements
- if num_entities < 0:
- if exporting_to_onnx.is_exporting():
- raise UnityTrainerException(
- "Trying to export an attention mechanism that doesn't have a set max \
- number of elements."
- )
- num_entities = entities.shape[1]
- expanded_self = x_self.reshape(-1, 1, self.self_size)
- expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+ def forward(
+ self, x_self: torch.Tensor, entities: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, int]:
+ if self.concat_self:
- entities = torch.cat([expanded_self, entities], dim=2)
- # Encode entities
- encoded_entities = self.self_ent_encoder(entities)
+ self_and_ent: List[torch.Tensor] = []
+ for num_entities, ent in zip(self.entity_num_max_elements, entities):
+ if num_entities < 0:
+ if exporting_to_onnx.is_exporting():
+ raise UnityTrainerException(
+ "Trying to export an attention mechanism that doesn't have a set max \
+ number of elements."
+ )
+ num_entities = ent.shape[1]
+ expanded_self = x_self.reshape(-1, 1, self.self_size)
+ expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+ self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
+ else:
+ self_and_ent = entities
+ # Encode and concatenate entites
+ encoded_entities = torch.cat(
+ [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
+ dim=1,
+ )
encoded_entities = self.embedding_norm(encoded_entities)
+ @staticmethod
+ def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
+ """
+ Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
+ all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
+ layer to mask the padding observations.
+ """
+ with torch.no_grad():
+ # Generate the masking tensors for each entities tensor (mask only if all zeros)
+ key_masks: List[torch.Tensor] = [
+ (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
+ ]
+ return key_masks
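The restored forward pass above tiles x_self across every entity (when concat_self is True) and concatenates it along the feature dimension before encoding. A standalone sketch of just that tiling step, with arbitrary shapes and no dependence on the class itself:

import torch

batch_size, self_size, ent_size, n_entities = 20, 3, 3, 5

x_self = torch.rand((batch_size, self_size))
entities = torch.rand((batch_size, n_entities, ent_size))

# Tile x_self once per entity, then concatenate along the feature dimension,
# mirroring the concat_self branch of EntityEmbeddings.forward above.
expanded_self = x_self.reshape(-1, 1, self_size)
expanded_self = torch.cat([expanded_self] * n_entities, dim=1)  # (batch, n_entities, self_size)
self_and_ent = torch.cat([expanded_self, entities], dim=2)      # (batch, n_entities, self_size + ent_size)

assert self_and_ent.shape == (batch_size, n_entities, self_size + ent_size)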
- with an EntityEmbedding module, to apply multi head self attention to encode information
+ with an EntityEmbeddings module, to apply multi head self attention to encode information
about a "Self" and a list of relevant "Entities".
"""

self,
embedding_size: int,
- entity_num_max_elements: Optional[int] = None,
+ entity_num_max_elements: Optional[List[int]] = None,
num_heads: int = 4,
):
"""

super().__init__()
self.max_num_ent: Optional[int] = None
if entity_num_max_elements is not None:
- self.max_num_ent = entity_num_max_elements
+ _entity_num_max_elements = entity_num_max_elements
+ self.max_num_ent = sum(_entity_num_max_elements)
self.attention = MultiHeadAttention(
num_heads=num_heads, embedding_size=embedding_size

kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.embedding_norm = LayerNorm()
inp = self.embedding_norm(inp)
# Feed to self attention
query = self.fc_q(inp) # (b, n_q, emb)
key = self.fc_k(inp) # (b, n_k, emb)
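Taken together, the flow the restored tests exercise is: embed the entities (with the self observation concatenated in), build padding masks, run residual self-attention, and project to the output size. A condensed sketch of that pipeline; the signatures are inferred from the test file above, and the linear_layer import path is an assumption rather than something shown in this diff:

import torch
from mlagents.trainers.torch.attention import EntityEmbeddings, ResidualSelfAttention
from mlagents.trainers.torch.layers import linear_layer  # assumed location of linear_layer

batch_size, size, n_k, embedding_size = 20, 3, 5, 64

entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
transformer = ResidualSelfAttention(embedding_size, [n_k])
l_layer = linear_layer(embedding_size, size)

center = torch.rand((batch_size, size))    # the "self" observation
key = torch.rand((batch_size, n_k, size))  # a variable-length set of entities

embeddings = entity_embeddings(center, [key])  # one encoder per entity type, concatenated
masks = EntityEmbeddings.get_masks([key])      # 1 where an entity row is all zeros
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction).reshape((batch_size, size))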
