|
|
|
|
|
|
from mlagents.torch_utils import torch |
|
|
|
from typing import Tuple, Optional, List |
|
|
|
from mlagents.trainers.torch.layers import (
    LinearEncoder,
    linear_layer,
    Initialization,
    LayerNorm,
)
|
|
|
|
|
|
|
print("-" * 10 + ' Incoming Gradients ' + '-' * 10) |
|
|
|
print("-" * 10 + " Incoming Gradients " + "-" * 10) |
|
|
|
print('Incoming Grad value: {}'.format(inp[0].data)) |
|
|
|
print("Incoming Grad value: {}".format(inp[0].data)) |
|
|
|
print('Upstream Grad value: {}'.format(out[0].data)) |
|
|
|
print("Upstream Grad value: {}".format(out[0].data)) |
|
|
|
|
|
|
|
|
|
|
|
class MultiHeadAttention(torch.nn.Module): |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
NEG_INF = -1e6 |
|
|
|
|
|
|
|
    def __init__(self, num_heads: int, embedding_size: int):
        super().__init__()
        # The attribute assignments below were missing and are reconstructed
        # from how self.n_heads and self.embedding_size are used further down.
        self.n_heads = num_heads
        self.embedding_size = embedding_size
|
|
|
        # self.fc_q = torch.nn.linear(query_size, self.n_heads * self.embedding_size)
        # self.fc_k = torch.nn.linear(key_size, self.n_heads * self.embedding_size)
        # self.fc_v = torch.nn.linear(value_size, self.n_heads * self.embedding_size)
        # self.fc_q = torch.nn.Linear(self.embedding_size // self.n_heads, self.embedding_size // self.n_heads)
        # self.fc_k = torch.nn.Linear(self.embedding_size // self.n_heads, self.embedding_size // self.n_heads)
        # self.fc_v = torch.nn.Linear(self.embedding_size // self.n_heads, self.embedding_size // self.n_heads)
        # self.fc_qk = torch.nn.Linear(self.embedding_size, 2 * self.embedding_size)
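        # Note: with all of the q/k/v projection layers above commented out, the
        # (normalized) input embedding is used directly as query, key and value
        # in forward below.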
|
|
|
|
|
|
|
        self.fc_out = torch.nn.Linear(self.embedding_size, self.embedding_size)
|
|
|
        # for layer in [self.fc_qk, self.fc_v, self.fc_out]:
        # Only fc_out is instantiated in this version, so the scaled normal
        # initialization is applied to it directly (the loop header above was
        # commented out while its body was not).
        torch.nn.init.normal_(
            self.fc_out.weight, std=(0.125 / embedding_size) ** 0.5
        )
|
|
|
self.embedding_norm = torch.nn.LayerNorm(embedding_size) |
|
|
|
|
|
|
|
    def forward(
        self,
        inp: torch.Tensor,
        key_mask: torch.Tensor,
        number_of_queries: int = -1,
        number_of_keys: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The parameter list was missing here; it is reconstructed from the call
        # site below (self.attention(qkv, mask, max_num_ent, max_num_ent)) and
        # from the names used in the body, so treat the exact signature as an
        # assumption.
        b = inp.size(0)
        inp = self.embedding_norm(inp)
        # With the q/k/v projections disabled in __init__, the normalized input
        # serves directly as query, key and value (reconstruction).
        query = inp
        key = inp
        value = inp
        n_q = number_of_queries if number_of_queries != -1 else query.size(1)
        n_k = number_of_keys if number_of_keys != -1 else key.size(1)
|
|
|
        # query = self.fc_q(inp)
        # qk = self.fc_qk(inp)
        # qk = qk.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads, 2)
        # query, key = torch.split(qk, 1, dim=-1)
        # query = torch.squeeze(query, dim=-1)
        # key = torch.squeeze(key, dim=-1)
        # query = query.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads)
        # key = key.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
        # value = value.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
        # query = self.fc_q(query)  # (b, n_q, h*d)
        # key = self.fc_k(key)  # (b, n_k, h*d)
        # value = self.fc_v(value)  # (b, n_k, h*d)

        # The active reshape into heads was missing and is reconstructed here:
        # split the embedding dimension across the attention heads.
        query = query.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads)
        key = key.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
        value = value.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
|
|
|
|
|
|
|
query = query.permute([0, 2, 1, 3]) # (b, h, n_q, emb) |
|
|
|
        # The next few lines are equivalent to : key.permute([0, 2, 3, 1])
        # The rest of the attention computation was missing here and is
        # reconstructed as standard masked, scaled dot-product attention; treat
        # the details as assumptions.
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
        key = key.permute([0, 1, 3, 2])  # (b, h, emb, n_k)
        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)
        # Scale the scores and push masked keys towards NEG_INF before softmax.
        key_mask = key_mask.reshape(b, 1, 1, n_k)
        qk = (1 - key_mask) * qk / (
            self.embedding_size ** 0.5
        ) + key_mask * self.NEG_INF

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb)
        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb)
        value_attention = value_attention.reshape(
            b, n_q, self.embedding_size
        )  # (b, n_q, h*emb)
|
|
|
out = self.fc_out(value_attention) # (b, n_q, emb) |
|
|
|
        # if out.requires_grad:
        # out = self.out_norm(out)
|
|
|
return out, att |
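# Usage sketch (illustrative, assumed shapes): with num_heads=4 and
# embedding_size=64, calling MultiHeadAttention on an input of shape
# (batch, n_entities, 64) with a (batch, n_entities) key mask returns an
# output of shape (batch, n_entities, 64) and attention weights of shape
# (batch, 4, n_entities, n_entities).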
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The class statement for this entity-attention module was missing; the name
# "EntityAttention" is a placeholder (reconstruction).
class EntityAttention(torch.nn.Module):
    """
    Embeds a variable number of entity observations, runs multi-head
    self-attention over them and average-pools the masked result.
    """

    EPSILON = 1e-7
|
|
|
|
|
|
|
    def __init__(
        self, x_self_size: int, entities_sizes: List[int], embedding_size: int
    ):
|
|
|
super().__init__() |
|
|
|
self.self_size = x_self_size |
|
|
|
|
|
|
        # The ModuleList / list-comprehension wrapper around the per-entity
        # encoders was missing and is reconstructed here (one encoder per entity
        # type); the attribute name ent_encoders is an assumption.
        self.ent_encoders = torch.nn.ModuleList(
            [
                # LinearEncoder(self.self_size + ent_size, 2, embedding_size)
                # from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
                # linear_layer(self.self_size + ent_size, embedding_size, Initialization.Normal, kernel_gain=(.125 / (self.self_size + ent_size)) ** 0.5)
                # linear_layer(ent_size, embedding_size, Initialization.Normal, kernel_gain=(.125 / (self.self_size + ent_size)) ** 0.5)
                # LinearEncoder(self.self_size + ent_size, 1, embedding_size, layer_norm=False)
                LinearEncoder(
                    ent_size,
                    1,
                    embedding_size,
                    kernel_init=Initialization.Normal,
                    kernel_gain=(0.125 / embedding_size) ** 0.5,
                )
                for ent_size in entities_sizes
            ]
        )
|
|
|
        self.attention = MultiHeadAttention(num_heads=4, embedding_size=embedding_size)
        # self.residual_layer = torch.nn.Linear(
        # )
        # torch.torch.nn.init.normal_(self.residual_layer.weight, std = 0.125 * embedding_size ** -0.5)
|
|
|
|
|
|
|
self.res_norm = torch.nn.LayerNorm(embedding_size) |
|
|
|
|
|
|
|
|
|
|
    def forward(
        self,
        x_self: torch.Tensor,
        entities: List[torch.Tensor],
        key_masks: List[torch.Tensor],
    ) -> torch.Tensor:
        # The forward signature was missing and is reconstructed from the names
        # used in the body (x_self, entities and the key masks); treat it as an
        # assumption.
        # Gather the maximum number of entities for each entity type.
        self.entities_num_max_elements: List[int] = []
        for ent in entities:
            self.entities_num_max_elements.append(ent.shape[1])
|
|
|
# Concatenate all observations with self |
|
|
|
        # self_and_ent: List[torch.Tensor] = []
        # for num_entities, ent in zip(self.entities_num_max_elements, entities):
        #     expanded_self = x_self.reshape(-1, 1, self.self_size)
        #     # .repeat(
        #     #     1, num_entities, 1
        # The active construction of qkv and of the combined mask was missing
        # here; this reconstruction encodes each entity type with its own encoder
        # and concatenates along the entity axis (treat it as an assumption).
        qkv = torch.cat(
            [encoder(ent) for encoder, ent in zip(self.ent_encoders, entities)],
            dim=1,
        )
        mask = torch.cat(key_masks, dim=1)
|
|
|
|
# Feed to self attention |
|
|
|
max_num_ent = sum(self.entities_num_max_elements) |
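        # Run multi-head self-attention over all encoded entities; the second
        # return value (the per-head attention weights) is not used here.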
|
|
|
output, _ = self.attention(qkv, mask, max_num_ent, max_num_ent) |
|
|
|
        # residual 1
        # The residual connection itself was missing here; adding the attention
        # output back to its input and normalizing with res_norm is a
        # reconstruction (assumption).
        output = self.res_norm(output + qkv)
        # residual 2
        # output = self.residual_layer(output) + output  # qkv
        # average pooling
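        # Masked average pooling over entities: (1 - mask) zeroes out padded
        # entries, and the sum is divided by the number of real entities
        # (EPSILON guards against division by zero when everything is masked).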
|
|
|
numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1) |
|
|
|
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
|
|
|
output = numerator / denominator |
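        # Usage sketch (illustrative, assumed shapes): with embedding_size=64,
        # two entity types of sizes [8, 6], x_self of shape (batch, x_self_size),
        # entities of shapes [(batch, n1, 8), (batch, n2, 6)] and key masks of
        # shapes [(batch, n1), (batch, n2)], the pooled output has shape
        # (batch, 64).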
|
|
|