from mlagents.torch_utils import torch
from typing import Tuple, Optional, List
from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer


class MultiHeadAttention(torch.nn.Module):
    """
    Multi Head Attention module. Takes as input to the forward method 3 tensors:
    - query: of dimensions (batch_size, number_of_queries, embedding_size)
    - key: of dimensions (batch_size, number_of_keys, embedding_size)
    - value: of dimensions (batch_size, number_of_keys, embedding_size)
    The forward method returns 2 tensors:
    - The output: (batch_size, number_of_queries, embedding_size)
    - The attention weights: (batch_size, num_heads, number_of_queries, number_of_keys)
    """

    NEG_INF = -1e6  # large negative value used to mask attention to padded keys

    def __init__(self, embedding_size: int, num_heads: int):
        super().__init__()
        self.n_heads = num_heads
        self.head_size: int = embedding_size // self.n_heads
        self.embedding_size: int = self.head_size * self.n_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        n_q: int,
        n_k: int,
        key_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # n_q and n_k are passed in explicitly to avoid using .size() when
        # possible, as Barracuda does not support it.
        b = -1  # the batch size, inferred by the reshape calls below

        query = query.reshape(
            b, n_q, self.n_heads, self.head_size
        )  # (b, n_q, h, emb / h)
        key = key.reshape(b, n_k, self.n_heads, self.head_size)  # (b, n_k, h, emb / h)
        value = value.reshape(
            b, n_k, self.n_heads, self.head_size
        )  # (b, n_k, h, emb / h)

        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb / h)
        # The two permutes below together apply permute([0, 2, 3, 1]) to key
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        key = key.permute([0, 1, 3, 2])  # (b, h, emb / h, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

        # Scale the logits and, if a mask is provided, push padded keys to a
        # large negative value so they receive (close to) zero attention weight
        if key_mask is None:
            qk = qk / (self.embedding_size ** 0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = (1 - key_mask) * qk / (
                self.embedding_size ** 0.5
            ) + key_mask * self.NEG_INF

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb / h)

        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb / h)
        value_attention = value_attention.reshape(
            b, n_q, self.embedding_size
        )  # (b, n_q, emb)

        return value_attention, att
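
# A minimal usage sketch (not part of the module): the caller provides the
# query/key/value projections and the number of queries and keys. The sizes
# used here (batch of 2, 3 entities, embedding 64, 4 heads) are illustrative.
#
#   mha = MultiHeadAttention(embedding_size=64, num_heads=4)
#   x = torch.ones((2, 3, 64))
#   out, att = mha(x, x, x, n_q=3, n_k=3)
#   # out is (2, 3, 64), att is (2, 4, 3, 3)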


class EntityEmbeddings(torch.nn.Module):
    """
    Concatenates the encoding of a "Self" to each entity of a list of entity
    observations, then embeds each entity with its own LinearEncoder. Used
    together with ResidualSelfAttention below, in an architecture inspired from
    https://arxiv.org/pdf/1909.07528.pdf.
    """

    def __init__(
        self,
        x_self_size: int,
        entity_sizes: List[int],
        embedding_size: int,
        entity_num_max_elements: List[int],
        concat_self: bool = True,
    ):
        super().__init__()
        self.self_size: int = x_self_size
        self.entity_sizes: List[int] = entity_sizes
        self.entity_num_max_elements: List[int] = entity_num_max_elements
        self.concat_self: bool = concat_self
        # If not concatenating self, input to encoder is just entity size
        if not concat_self:
            self.self_size = 0
        # One encoder per entity type
        self.ent_encoders = torch.nn.ModuleList(
            [
                LinearEncoder(self.self_size + ent_size, 1, embedding_size)
                for ent_size in self.entity_sizes
            ]
        )

    def forward(
        self, x_self: torch.Tensor, entities: List[torch.Tensor]
    ) -> torch.Tensor:
        if self.concat_self:
            # Concatenate all observations with self
            self_and_ent: List[torch.Tensor] = []
            for num_entities, ent in zip(self.entity_num_max_elements, entities):
                expanded_self = x_self.reshape(-1, 1, self.self_size)
                expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
                self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
        else:
            self_and_ent = entities
        # Encode each (self, entity) pair and concatenate along the entity dimension
        encoded_entities = torch.cat(
            [
                ent_encoder(x)
                for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
            ],
            dim=1,
        )
        return encoded_entities

    @staticmethod
    def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Returns one mask tensor per observation tensor, with a 1 for every entity
        that is all zeros along its feature dimension (i.e. a padded entity) and
        a 0 otherwise. These masks are what ResidualSelfAttention expects as
        key_masks.
        """
        key_masks: List[torch.Tensor] = [
            # An entity is treated as padding if its features are (nearly) all zero
            (torch.sum(ent ** 2, dim=2) < 0.01).float()
            for ent in observations
        ]
        return key_masks
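
    # A minimal usage sketch (illustrative sizes only): two entity types, with
    # up to 3 and 2 entities respectively, embedded to size 64 together with a
    # self encoding of size 6.
    #
    #   ee = EntityEmbeddings(
    #       x_self_size=6,
    #       entity_sizes=[4, 5],
    #       embedding_size=64,
    #       entity_num_max_elements=[3, 2],
    #   )
    #   x_self = torch.ones((2, 6))
    #   obs = [torch.ones((2, 3, 4)), torch.ones((2, 2, 5))]
    #   embedded = ee(x_self, obs)               # (2, 5, 64)
    #   masks = EntityEmbeddings.get_masks(obs)  # one (2, n) mask per entity type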


class ResidualSelfAttention(torch.nn.Module):
    """
    A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses
    multi head self attention to encode information about a "Self" and a list of
    relevant "Entities".
    """

    EPSILON = 1e-7

    def __init__(
        self,
        embedding_size: int,
        entity_num_max_elements: List[int],
        num_heads: int = 4,
    ):
        super().__init__()
        self.entity_num_max_elements: List[int] = entity_num_max_elements
        self.max_num_ent = sum(entity_num_max_elements)
        self.attention = MultiHeadAttention(
            num_heads=num_heads, embedding_size=embedding_size
        )

        self.fc_q = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
        self.fc_k = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
        self.fc_v = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
        self.fc_out = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
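        # Note: the small kernel_gain of (0.125 / embedding_size) ** 0.5 scales
        # the Normal initialization down, presumably so that the attention
        # branch starts out small relative to the residual connection.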

    def forward(
        self, inp: torch.Tensor, key_masks: List[torch.Tensor]
    ) -> torch.Tensor:
        # Gather the padding masks of all entities into a single mask
        mask = torch.cat(key_masks, dim=1)
        # Feed to self attention
        query = self.fc_q(inp)  # (b, n_q, emb)
        key = self.fc_k(inp)  # (b, n_k, emb)
        value = self.fc_v(inp)  # (b, n_k, emb)
        output, _ = self.attention(
            query, key, value, self.max_num_ent, self.max_num_ent, mask
        )
        # Residual
        output = self.fc_out(output) + inp
        # Average pooling over the entities that are not padding
        numerator = torch.sum(
            output * (1 - mask).reshape(-1, self.max_num_ent, 1), dim=1
        )
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
        return output
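

# A minimal end-to-end sketch (illustrative sizes only), combining the two
# modules the way a network body might:
#
#   ee = EntityEmbeddings(
#       x_self_size=6, entity_sizes=[4], embedding_size=64, entity_num_max_elements=[3]
#   )
#   rsa = ResidualSelfAttention(embedding_size=64, entity_num_max_elements=[3])
#   x_self = torch.ones((2, 6))
#   obs = [torch.ones((2, 3, 4))]
#   masks = EntityEmbeddings.get_masks(obs)
#   encoded = ee(x_self, obs)     # (2, 3, 64)
#   pooled = rsa(encoded, masks)  # (2, 64)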