|
|
|
|
|
|
from mlagents.torch_utils import torch |
|
|
|
from typing import Tuple, Optional, List |
|
|
|
from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer |
|
|
|
from mlagents.trainers.torch.model_serialization import exporting_to_onnx |
|
|
|
from mlagents.trainers.exception import UnityTrainerException |
|
|
|
|
|
|
|
|
|
|
|
class MultiHeadAttention(torch.nn.Module): |
|
|
|
|
|
|
self, |
|
|
|
x_self_size: int, |
|
|
|
entity_sizes: List[int], |
|
|
|
entity_num_max_elements: List[int], |
|
|
|
entity_num_max_elements: Optional[List[int]], |
|
|
|
A module that generates embeddings for a variable number of entities given "self" |
|
|
|
encoding as well as a representation of each entity of each type. The expected |
|
|
|
input of the forward method will be a list of tensors: each element of the |
|
|
|
list corresponds to a type of entity, the dimension of each tensor is : |
|
|
|
[batch_size, max_num_entities, entity_size] |
|
|
|
:param x_self_size: The size of the self embedding that will be concatenated |
|
|
|
with the entities |
|
|
|
:param entity_sizes: The size of each entity type |
|
|
|
:param entity_num_max_elements: A list of maximum number of entities, must be |
|
|
|
the same length as the number of entity tensors that will be passed to the |
|
|
|
forward method. |
|
|
|
:param embedding_size: The size of the output embeddings |
|
|
|
:param concat_self: If true, the x_self will be concatenated with the entities |
|
|
|
before embedding |
|
|
|
Constructs an EntityEmbeddings module. |
|
|
|
:param x_self_size: Size of "self" entity. |
|
|
|
:param entity_sizes: List of sizes for other entities. Should be of length |
|
|
|
equivalent to the number of entities. |
|
|
|
:param entity_num_max_elements: Maximum elements in an entity, None for unrestricted. |
|
|
|
Needs to be assigned in order for model to be exportable to ONNX and Barracuda. |
|
|
|
:param embedding_size: Embedding size for entity encoders. |
|
|
|
:param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric |
|
|
|
self-attention. |
|
|
|
self.entity_num_max_elements: List[int] = entity_num_max_elements |
|
|
|
self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes) |
|
|
|
if entity_num_max_elements is not None: |
|
|
|
self.entity_num_max_elements = entity_num_max_elements |
|
|
|
|
|
|
|
self.concat_self: bool = concat_self |
|
|
|
# If not concatenating self, input to encoder is just entity size |
|
|
|
if not concat_self: |
|
|
|
|
|
|
# Concatenate all observations with self |
|
|
|
self_and_ent: List[torch.Tensor] = [] |
|
|
|
for num_entities, ent in zip(self.entity_num_max_elements, entities): |
|
|
|
if num_entities < 0: |
|
|
|
if exporting_to_onnx.is_exporting(): |
|
|
|
raise UnityTrainerException( |
|
|
|
"Trying to export an attention mechanism that doesn't have a set max \ |
|
|
|
number of elements." |
|
|
|
) |
|
|
|
num_entities = ent.shape[1] |
|
|
|
expanded_self = x_self.reshape(-1, 1, self.self_size) |
|
|
|
expanded_self = torch.cat([expanded_self] * num_entities, dim=1) |
|
|
|
self_and_ent.append(torch.cat([expanded_self, ent], dim=2)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualSelfAttention(torch.nn.Module): |
|
|
|
""" |
|
|
|
Residual self attentioninspired from https://arxiv.org/pdf/1909.07528.pdf. Can be used |
|
|
|
with an EntityEmbeddings module, to apply multi head self attention to encode information |
|
|
|
about a "Self" and a list of relevant "Entities". |
|
|
|
""" |
|
|
|
|
|
|
|
self, embedding_size: int, total_max_elements: int, num_heads: int = 4 |
|
|
|
self, |
|
|
|
embedding_size: int, |
|
|
|
entity_num_max_elements: Optional[int] = None, |
|
|
|
num_heads: int = 4, |
|
|
|
A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses |
|
|
|
multi head self attention to encode information about a "Self" and a list of |
|
|
|
relevant "Entities". |
|
|
|
:param embedding_size: The size of the embeddings that will be generated (should be |
|
|
|
dividable by the num_heads) |
|
|
|
:param total_max_elements: The maximum total number of entities that can be passed to |
|
|
|
the module |
|
|
|
:param num_heads: The number of heads of the attention module |
|
|
|
Constructs a ResidualSelfAttention module. |
|
|
|
:param embedding_size: Embedding sizee for attention mechanism and |
|
|
|
Q, K, V encoders. |
|
|
|
:param entity_num_max_elements: A List of ints representing the maximum number |
|
|
|
of elements in an entity sequence. Should be of length num_entities. Pass None to |
|
|
|
not restrict the number of elements; however, this will make the module |
|
|
|
unexportable to ONNX/Barracuda. |
|
|
|
:param num_heads: Number of heads for Multi Head Self-Attention |
|
|
|
self.max_num_ent = total_max_elements |
|
|
|
self.max_num_ent: Optional[int] = None |
|
|
|
if entity_num_max_elements is not None: |
|
|
|
self.max_num_ent = entity_num_max_elements |
|
|
|
|
|
|
|
self.attention = MultiHeadAttention( |
|
|
|
num_heads=num_heads, embedding_size=embedding_size |
|
|
|
) |
|
|
|
|
|
|
query = self.fc_q(inp) # (b, n_q, emb) |
|
|
|
key = self.fc_k(inp) # (b, n_k, emb) |
|
|
|
value = self.fc_v(inp) # (b, n_k, emb) |
|
|
|
output, _ = self.attention( |
|
|
|
query, key, value, self.max_num_ent, self.max_num_ent, mask |
|
|
|
) |
|
|
|
|
|
|
|
# Only use max num if provided |
|
|
|
if self.max_num_ent is not None: |
|
|
|
num_ent = self.max_num_ent |
|
|
|
else: |
|
|
|
num_ent = inp.shape[1] |
|
|
|
if exporting_to_onnx.is_exporting(): |
|
|
|
raise UnityTrainerException( |
|
|
|
"Trying to export an attention mechanism that doesn't have a set max \ |
|
|
|
number of elements." |
|
|
|
) |
|
|
|
|
|
|
|
output, _ = self.attention(query, key, value, num_ent, num_ent, mask) |
|
|
|
numerator = torch.sum( |
|
|
|
output * (1 - mask).reshape(-1, self.max_num_ent, 1), dim=1 |
|
|
|
) |
|
|
|
numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1) |
|
|
|
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON |
|
|
|
output = numerator / denominator |
|
|
|
# Residual between x_self and the output of the module |