|
|
|
|
|
|
from mlagents.torch_utils import torch |
|
|
|
from typing import Tuple, Optional, List |
|
|
|
from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer |
|
|
|
from mlagents.trainers.torch.model_serialization import exporting_to_onnx |
|
|
|
from mlagents.trainers.exception import UnityTrainerException |
|
|
|
|
|
|
|
|
|
|
|
class MultiHeadAttention(torch.nn.Module): |
|
|
|
|
|
|
|
|
|
|
def __init__(self, embedding_size: int, num_heads: int): |
|
|
|
super().__init__() |
|
|
|
self.n_heads, self.embedding_size = num_heads, embedding_size |
|
|
|
self.head_size: int = self.embedding_size // self.n_heads |
|
|
|
self.n_heads = num_heads |
|
|
|
self.head_size: int = embedding_size // self.n_heads |
|
|
|
self.embedding_size: int = self.head_size * self.n_heads |
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
|
|
|
|
|
|
|
class EntityEmbeddings(torch.nn.Module): |
|
|
|
""" |
|
|
|
A module used to embed entities before passing them to a self-attention block. |
|
|
|
Used in conjunction with ResidualSelfAttention to encode information about a self |
|
|
|
and additional entities. Can also concatenate self to entities for ego-centric self- |
|
|
|
attention. Inspired by architecture used in https://arxiv.org/pdf/1909.07528.pdf. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
|
|
|
entity_num_max_elements: Optional[List[int]] = None, |
|
|
|
concat_self: bool = True, |
|
|
|
): |
|
|
|
""" |
|
|
|
Constructs an EntityEmbeddings module. |
|
|
|
:param x_self_size: Size of "self" entity. |
|
|
|
:param entity_sizes: List of sizes for other entities. Should be of length |
|
|
|
equivalent to the number of entities. |
|
|
|
:param embedding_size: Embedding size for entity encoders. |
|
|
|
:param entity_num_max_elements: Maximum elements in an entity, None for unrestricted. |
|
|
|
Needs to be assigned in order for model to be exportable to ONNX and Barracuda. |
|
|
|
:param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric |
|
|
|
self-attention. |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
self.self_size: int = x_self_size |
|
|
|
self.entity_sizes: List[int] = entity_sizes |
|
|
|
|
|
|
self_and_ent: List[torch.Tensor] = [] |
|
|
|
for num_entities, ent in zip(self.entity_num_max_elements, entities): |
|
|
|
if num_entities < 0: |
|
|
|
if exporting_to_onnx.is_exporting(): |
|
|
|
raise UnityTrainerException( |
|
|
|
"Trying to export an attention mechanism that doesn't have a set max \ |
|
|
|
number of elements." |
|
|
|
) |
|
|
|
num_entities = ent.shape[1] |
|
|
|
expanded_self = x_self.reshape(-1, 1, self.self_size) |
|
|
|
expanded_self = torch.cat([expanded_self] * num_entities, dim=1) |
|
|
|
|
|
|
return key_masks |
|
|
|
|
|
|
|
|
|
|
|
class SmallestAttention(torch.nn.Module): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
x_self_size: int, |
|
|
|
entities_sizes: List[int], |
|
|
|
embedding_size: int, |
|
|
|
output_size: Optional[int] = None, |
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.self_size = x_self_size |
|
|
|
self.entities_sizes = entities_sizes |
|
|
|
self.entities_num_max_elements: Optional[List[int]] = None |
|
|
|
self.ent_encoders = torch.nn.ModuleList( |
|
|
|
[ |
|
|
|
LinearEncoder(self.self_size + ent_size, 2, embedding_size) |
|
|
|
# LinearEncoder(self.self_size + ent_size, 3, embedding_size) |
|
|
|
# LinearEncoder(self.self_size + ent_size, 1, embedding_size) |
|
|
|
for ent_size in self.entities_sizes |
|
|
|
] |
|
|
|
) |
|
|
|
self.importance_layer = LinearEncoder(embedding_size, 1, 1) |
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
x_self: torch.Tensor, |
|
|
|
entities: List[torch.Tensor], |
|
|
|
key_masks: List[torch.Tensor], |
|
|
|
) -> torch.Tensor: |
|
|
|
# Gather the maximum number of entities information |
|
|
|
if self.entities_num_max_elements is None: |
|
|
|
self.entities_num_max_elements = [] |
|
|
|
for ent in entities: |
|
|
|
self.entities_num_max_elements.append(ent.shape[1]) |
|
|
|
# Concatenate all observations with self |
|
|
|
self_and_ent: List[torch.Tensor] = [] |
|
|
|
for num_entities, ent in zip(self.entities_num_max_elements, entities): |
|
|
|
expanded_self = x_self.reshape(-1, 1, self.self_size) |
|
|
|
# .repeat( |
|
|
|
# 1, num_entities, 1 |
|
|
|
# ) |
|
|
|
expanded_self = torch.cat([expanded_self] * num_entities, dim=1) |
|
|
|
self_and_ent.append(torch.cat([expanded_self, ent], dim=2)) |
|
|
|
# Generate the tensor that will serve as query, key and value to self attention |
|
|
|
qkv = torch.cat( |
|
|
|
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)], |
|
|
|
dim=1, |
|
|
|
) |
|
|
|
mask = torch.cat(key_masks, dim=1) |
|
|
|
# Feed to self attention |
|
|
|
importance = self.importance_layer(qkv) + mask.unsqueeze(2) * -1e6 |
|
|
|
importance = torch.softmax(importance, dim=1) |
|
|
|
weighted_qkv = qkv * importance |
|
|
|
|
|
|
|
output = torch.sum(weighted_qkv, dim=1) |
|
|
|
output = torch.cat([output, x_self], dim=1) |
|
|
|
|
|
|
|
|
|
|
|
A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses |
|
|
|
multi head self attention to encode information about a "Self" and a list of |
|
|
|
relevant "Entities". |
|
|
|
Residual self attentioninspired from https://arxiv.org/pdf/1909.07528.pdf. Can be used |
|
|
|
with an EntityEmbeddings module, to apply multi head self attention to encode information |
|
|
|
about a "Self" and a list of relevant "Entities". |
|
|
|
""" |
|
|
|
|
|
|
|
EPSILON = 1e-7 |
|
|
|
|
|
|
entity_num_max_elements: Optional[List[int]] = None, |
|
|
|
num_heads: int = 4, |
|
|
|
): |
|
|
|
""" |
|
|
|
Constructs a ResidualSelfAttention module. |
|
|
|
:param embedding_size: Embedding sizee for attention mechanism and |
|
|
|
Q, K, V encoders. |
|
|
|
:param entity_num_max_elements: A List of ints representing the maximum number |
|
|
|
of elements in an entity sequence. Should be of length num_entities. Pass None to |
|
|
|
not restrict the number of elements; however, this will make the module |
|
|
|
unexportable to ONNX/Barracuda. |
|
|
|
:param num_heads: Number of heads for Multi Head Self-Attention |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
self.max_num_ent: Optional[int] = None |
|
|
|
if entity_num_max_elements is not None: |
|
|
|
|
|
|
value = self.fc_v(inp) # (b, n_k, emb) |
|
|
|
|
|
|
|
# Only use max num if provided |
|
|
|
num_ent = self.max_num_ent if self.max_num_ent is not None else inp.shape[1] |
|
|
|
if self.max_num_ent is not None: |
|
|
|
num_ent = self.max_num_ent |
|
|
|
else: |
|
|
|
num_ent = inp.shape[1] |
|
|
|
if exporting_to_onnx.is_exporting(): |
|
|
|
raise UnityTrainerException( |
|
|
|
"Trying to export an attention mechanism that doesn't have a set max \ |
|
|
|
number of elements." |
|
|
|
) |
|
|
|
|
|
|
|
output, _ = self.attention(query, key, value, num_ent, num_ent, mask) |
|
|
|
# Residual |
|
|
|