|
|
|
|
|
|
from mlagents.torch_utils import torch

from typing import Tuple, Optional, List

from mlagents.trainers.torch.layers import LinearEncoder, linear_layer, Initialization, LayerNorm
|
|
|
def grad_hook(mod, inp, out):
    """
    Debugging helper: prints the gradients flowing into and out of a module
    when registered as a backward hook.
    """
    print("")
    print(mod)
    print("-" * 10 + " Incoming Gradients " + "-" * 10)
    print("")
    print("Incoming Grad value: {}".format(inp[0].data))
    print("")
    print("Upstream Grad value: {}".format(out[0].data))
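
# Hedged usage sketch (illustrative, not part of the original module): `grad_hook`
# has the (module, grad_input, grad_output) signature expected by torch's backward
# hooks, so it can be attached to a layer while debugging, e.g.:
#
#     debug_layer = torch.nn.Linear(4, 4)  # hypothetical layer
#     debug_layer.register_full_backward_hook(grad_hook)  # register_backward_hook on older torch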
|
|
|
|
|
|
|
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-head attention: queries, keys and values are projected into
    `num_heads` heads of size `embedding_size`, attended over, and the result
    is projected back to `output_size`.
    """

    def __init__(
        self,
        query_size: int,
        key_size: int,
        value_size: int,
        output_size: int,
        num_heads: int,
        embedding_size: int,
    ):
        super().__init__()
        self.n_heads, self.embedding_size = num_heads, embedding_size
        self.output_size = output_size
|
|
|
        # self.fc_q = linear_layer(
        #     query_size,
        #     self.n_heads * self.embedding_size,
        #     kernel_init=Initialization.KaimingHeNormal,
        #     kernel_gain=1.0,
        # )
        # self.fc_k = linear_layer(
        #     key_size,
        #     self.n_heads * self.embedding_size,
        #     kernel_init=Initialization.KaimingHeNormal,
        #     kernel_gain=1.0,
        # )
        # self.fc_v = linear_layer(
        #     value_size,
        #     self.n_heads * self.embedding_size,
        #     kernel_init=Initialization.KaimingHeNormal,
        #     kernel_gain=1.0,
        # )

        # self.fc_q = LinearEncoder(query_size, 2, self.n_heads * self.embedding_size)
        # self.fc_k = LinearEncoder(key_size, 2, self.n_heads * self.embedding_size)
        # self.fc_v = LinearEncoder(value_size, 2, self.n_heads * self.embedding_size)
|
|
|
|
|
|
        )  # (b, n_q, h*emb)

        out = self.fc_out(value_attention)  # (b, n_q, emb)
        # if out.requires_grad:
        #     out.register_hook(lambda x: print(x))
        # out = self.out_norm(out)
        return out, att
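
# Hedged usage sketch (illustrative only; mirrors the self-attention call further
# below, where the same tensor is passed as query, key and value):
#
#     emb = 64  # hypothetical sizes
#     attention = MultiHeadAttention(
#         query_size=emb, key_size=emb, value_size=emb,
#         output_size=emb, num_heads=4, embedding_size=emb,
#     )
#     out, att = attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
#
# `out` holds the attended embeddings, `att` the attention weights.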
|
|
|
|
|
|
|
|
|
|
|
|
|
|
                # LinearEncoder(self.self_size + ent_size, 2, embedding_size)
                # from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
                # linear_layer(self.self_size + ent_size, embedding_size, Initialization.Normal, kernel_gain=1 / (self.self_size + ent_size) ** 0.5)
                # LinearEncoder(self.self_size + ent_size, 1, embedding_size)
                LinearEncoder(self.self_size + ent_size, 1, embedding_size, layer_norm=True)
                for ent_size in self.entities_sizes
            ]
        )
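
        # One encoder per entity type: each encoder takes the agent's own features
        # concatenated with a single entity's features (hence the
        # `self.self_size + ent_size` input width) and maps them to `embedding_size`.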
|
|
|
|
|
|
        self.attention = MultiHeadAttention(
            query_size=embedding_size,
            key_size=embedding_size,
            value_size=embedding_size,
            output_size=embedding_size,
            num_heads=4,
            embedding_size=embedding_size,
        )
|
|
|
        self.residual_layer = LinearEncoder(embedding_size, 1, embedding_size)
        self.res_norm = torch.nn.LayerNorm(embedding_size, elementwise_affine=True)
|
|
|
        if output_size is None:
            output_size = embedding_size
|
|
|
|
|
|
|
|
|
|
        max_num_ent = sum(self.entities_num_max_elements)
        output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
        # Residual connection around the post-attention encoder
        output = self.residual_layer(output) + qkv
        # output += qkv
        output = self.res_norm(output + qkv)
        # output = self.res_norm(output)
        output = torch.cat([output, x_self], dim=1)
        return output
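
    # Forward-pass summary: the concatenated self+entity embeddings `qkv` go through
    # multi-head self-attention (`mask` marks entities to ignore), then an encoder
    # with a residual connection back to `qkv`, a LayerNorm, and finally a
    # concatenation with the raw `x_self` features.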
|
|
|
|
|
|
|
@staticmethod |
|
|
|