
Allow setting maximum number of elements in self-attention to None (#4841)

* separate entity encoder and RSA

* clean up args in mha

* more cleanups

* fixed tests

* entity embeddings have no max option

* Add exceptions for variable export

* Fix test

* Add docstrings

Co-authored-by: Andrew Cohen <andrew.cohen@unity3d.com>
GitHub · 4 years ago
Current commit b7e6efa3
2 changed files with 63 additions and 15 deletions
1. ml-agents/mlagents/trainers/tests/torch/test_attention.py (2 lines changed)
2. ml-agents/mlagents/trainers/torch/attention.py (76 lines changed)

ml-agents/mlagents/trainers/tests/torch/test_attention.py (2 lines changed)


torch.manual_seed(1336)
size, n_k, = 3, 5
embedding_size = 64
- entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
+ entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
transformer = ResidualSelfAttention(embedding_size, [n_k])
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
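The changed test line reflects the new argument order. Below is a minimal usage sketch of the two configurations this change enables; the positional order and the defaults for entity_num_max_elements and num_heads are assumptions read off the signatures in this diff, not verified against the full file.

# Hedged sketch -- constructor argument order and defaults are assumed from this diff.
from mlagents.trainers.torch.attention import EntityEmbeddings, ResidualSelfAttention

x_self_size, entity_size, n_entities, embedding_size = 3, 3, 5, 64

# Bounded: max elements given, so the module stays exportable to ONNX/Barracuda.
bounded_embed = EntityEmbeddings(x_self_size, [entity_size], embedding_size, [n_entities])
bounded_rsa = ResidualSelfAttention(embedding_size, [n_entities])

# Unbounded: entity_num_max_elements defaults to None; the entity count is read from
# the input at runtime, and exporting this configuration raises UnityTrainerException.
unbounded_embed = EntityEmbeddings(x_self_size, [entity_size], embedding_size)
unbounded_rsa = ResidualSelfAttention(embedding_size)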

ml-agents/mlagents/trainers/torch/attention.py (76 lines changed)


from mlagents.torch_utils import torch
from typing import Tuple, Optional, List
from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
from mlagents.trainers.exception import UnityTrainerException
class MultiHeadAttention(torch.nn.Module):

class EntityEmbeddings(torch.nn.Module):
"""
A module used to embed entities before passing them to a self-attention block.
Used in conjunction with ResidualSelfAttention to encode information about a self
and additional entities. Can also concatenate self to entities for ego-centric self-
attention. Inspired by architecture used in https://arxiv.org/pdf/1909.07528.pdf.
"""
def __init__(

- entity_num_max_elements: List[int],
+ entity_num_max_elements: Optional[List[int]] = None,
"""
Constructs an EntityEmbeddings module.
:param x_self_size: Size of "self" entity.
:param entity_sizes: List of sizes for other entities. Should be of length
equivalent to the number of entities.
:param embedding_size: Embedding size for entity encoders.
:param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
Needs to be assigned in order for model to be exportable to ONNX and Barracuda.
:param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
self-attention.
"""
- self.entity_num_max_elements: List[int] = entity_num_max_elements
+ self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
+ if entity_num_max_elements is not None:
+     self.entity_num_max_elements = entity_num_max_elements
self.concat_self: bool = concat_self
# If not concatenating self, input to encoder is just entity size
if not concat_self:

# Concatenate all observations with self
self_and_ent: List[torch.Tensor] = []
for num_entities, ent in zip(self.entity_num_max_elements, entities):
    if num_entities < 0:
        if exporting_to_onnx.is_exporting():
            raise UnityTrainerException(
                "Trying to export an attention mechanism that doesn't have a set max \
                number of elements."
            )
        num_entities = ent.shape[1]
    expanded_self = x_self.reshape(-1, 1, self.self_size)
    expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
    self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
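The loop above tiles the "self" observation across every entity slot before concatenating it to the entities. A self-contained PyTorch sketch of that broadcast-and-concatenate pattern, with invented shapes for illustration:

import torch

x_self = torch.randn(2, 3)       # (batch, self_size)
entities = torch.randn(2, 5, 4)  # (batch, num_entities, entity_size)

num_entities = entities.shape[1]          # dynamic count, as in the None (unrestricted) case
expanded_self = x_self.reshape(-1, 1, 3)  # (batch, 1, self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent = torch.cat([expanded_self, entities], dim=2)
print(self_and_ent.shape)  # torch.Size([2, 5, 7]) -> self_size + entity_size per entity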

class ResidualSelfAttention(torch.nn.Module):
"""
- A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses
- multi head self attention to encode information about a "Self" and a list of
- relevant "Entities".
+ Residual self attention inspired from https://arxiv.org/pdf/1909.07528.pdf. Can be used
+ with an EntityEmbeddings module, to apply multi head self attention to encode information
+ about a "Self" and a list of relevant "Entities".
"""
EPSILON = 1e-7

embedding_size: int,
- entity_num_max_elements: List[int],
+ entity_num_max_elements: Optional[List[int]] = None,
"""
Constructs a ResidualSelfAttention module.
:param embedding_size: Embedding size for attention mechanism and
Q, K, V encoders.
:param entity_num_max_elements: A List of ints representing the maximum number
of elements in an entity sequence. Should be of length num_entities. Pass None to
not restrict the number of elements; however, this will make the module
unexportable to ONNX/Barracuda.
:param num_heads: Number of heads for Multi Head Self-Attention
"""
- self.entity_num_max_elements: List[int] = entity_num_max_elements
- self.max_num_ent = sum(entity_num_max_elements)
+ self.max_num_ent: Optional[int] = None
+ if entity_num_max_elements is not None:
+     _entity_num_max_elements = entity_num_max_elements
+     self.max_num_ent = sum(_entity_num_max_elements)
self.attention = MultiHeadAttention(
num_heads=num_heads, embedding_size=embedding_size
)
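For intuition only, here is a rough stand-in for the residual block's query/key/value projections and masked attention, written against PyTorch's built-in torch.nn.MultiheadAttention rather than ml-agents' own MultiHeadAttention (the real module takes the sequence lengths explicitly, as the forward pass below shows):

import torch

embedding_size, num_heads, batch, num_ent = 64, 4, 2, 5

inp = torch.randn(batch, num_ent, embedding_size)
mask = torch.zeros(batch, num_ent)  # 1 marks a padded / absent entity slot
mask[:, -1] = 1

fc_q = torch.nn.Linear(embedding_size, embedding_size)
fc_k = torch.nn.Linear(embedding_size, embedding_size)
fc_v = torch.nn.Linear(embedding_size, embedding_size)
# batch_first requires a recent PyTorch (>= 1.9)
mha = torch.nn.MultiheadAttention(embedding_size, num_heads, batch_first=True)

query, key, value = fc_q(inp), fc_k(inp), fc_v(inp)
output, _ = mha(query, key, value, key_padding_mask=mask.bool())
print(output.shape)  # torch.Size([2, 5, 64])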

query = self.fc_q(inp) # (b, n_q, emb)
key = self.fc_k(inp) # (b, n_k, emb)
value = self.fc_v(inp) # (b, n_k, emb)
- output, _ = self.attention(
-     query, key, value, self.max_num_ent, self.max_num_ent, mask
- )
+ # Only use max num if provided
+ if self.max_num_ent is not None:
+     num_ent = self.max_num_ent
+ else:
+     num_ent = inp.shape[1]
+     if exporting_to_onnx.is_exporting():
+         raise UnityTrainerException(
+             "Trying to export an attention mechanism that doesn't have a set max \
+             number of elements."
+         )
+ output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
- numerator = torch.sum(
-     output * (1 - mask).reshape(-1, self.max_num_ent, 1), dim=1
- )
+ numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
output = numerator / denominator
# Residual between x_self and the output of the module
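The numerator/denominator lines above implement a masked mean over the entity dimension, so padded slots (mask == 1) do not dilute the pooled output. A tiny numeric sketch with made-up values:

import torch

EPSILON = 1e-7
num_ent = 3

output = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # (batch=1, num_ent, emb=2)
mask = torch.tensor([[0.0, 0.0, 1.0]])                             # third entity is padding

numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON
pooled = numerator / denominator
print(pooled)  # ~tensor([[2., 3.]]) -- the padded entity is excluded from the average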