def __init__(
    self,
    input_size: int,
    hidden_size: int,
    memory_size: int,
    num_layers: int = 1,
    batch_first: bool = True,
    forget_bias: float = 1.0,
):
    super().__init__()
    # The recurrent hidden size is half of the requested memory size.
    self.hidden_size = memory_size // 2
    # Assumed factory call: an lstm_layer(...) helper that builds an LSTM with a
    # configurable forget-gate bias; only its arguments are given here.
    self.lstm = lstm_layer(
        hidden_size,
        self.hidden_size,
        num_layers,
        batch_first,
        forget_bias,
    )
    self.hidden_size = hidden_size
    # Assumed structure: linear_layer(...) helpers collected into self.seq_layers,
    # the feed-forward stack applied to the aggregated output in the forward pass.
    self.seq_layers = torch.nn.Sequential(
        linear_layer(
            hidden_size,
            hidden_size,
        ),
        linear_layer(
            self.hidden_size,
            self.hidden_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.0,
        ),
    )

def forward(self, input_tensor, memories):
    # Assumed signature: memories packs the running-max accumulant m together
    # with the initial LSTM hidden and cell states h0 and c0.
    m, h0, c0 = torch.split(
        memories,
        ...,  # split sizes not shown
        dim=-1,
    )
    hidden = (h0, c0)
    # Assumed: run the LSTM, then split its output into the half to be
    # max-aggregated (h_half) and the half passed through unchanged (other_half).
    lstm_out, hidden = self.lstm(input_tensor, hidden)
    h_half, other_half = torch.split(lstm_out, lstm_out.shape[-1] // 2, dim=-1)
    # Per-timestep aggregation: maintain a running element-wise max of h_half
    # via the straight-through PassthroughMax op.
    all_c = []
    for t in range(h_half.shape[1]):
        h_half_subt = h_half[:, t : t + 1, :]
        m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
        all_c.append(m)
    concat_c = torch.cat(all_c, dim=1)
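
AMRLMax.PassthroughMax is referenced above but not defined in this excerpt. Assuming it implements the AMRL-style passthrough max (element-wise max in the forward pass, with the incoming gradient copied to both operands in the backward pass), a minimal sketch would be:

import torch

class PassthroughMax(torch.autograd.Function):
    # Element-wise max whose backward pass routes the full gradient to both
    # inputs instead of only the larger one (a straight-through estimator).
    @staticmethod
    def forward(ctx, tensor1, tensor2):
        return torch.max(tensor1, tensor2)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone(), grad_output.clone()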
|
|
|
    # Append incoming m to LSTM hidden outs
    cat_m = torch.cat([m, h_half], dim=1)
    concat_c, _ = torch.cummax(cat_m, dim=1)
    # Remove first element, which is just m
    concat_c = concat_c[:, 1:, :]
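
The torch.cummax call above produces the same running maximum that the per-timestep loop builds, in a single vectorized call (it returns both the cumulative values and their indices). A small self-contained check of that equivalence, using made-up shapes:

import torch

x = torch.randn(2, 5, 3)                 # [batch, time, features]
vectorized, _ = torch.cummax(x, dim=1)   # running max along the time axis

m = x[:, 0:1, :]                         # running accumulant, as in the loop above
steps = [m]
for t in range(1, x.shape[1]):
    m = torch.max(m, x[:, t : t + 1, :])
    steps.append(m)
looped = torch.cat(steps, dim=1)

assert torch.equal(vectorized, looped)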
|
|
|
    # Recombine the aggregated half with the untouched half, run the
    # post-processing layers per timestep, and restore the sequence layout.
    concat_out = torch.cat([concat_c, other_half], dim=-1)
    full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size]))
    full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size])
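
The final reshape pair flattens the batch and time dimensions so a per-timestep feed-forward stack can be applied, then restores the [batch, sequence, hidden] layout. A standalone sketch of that pattern, where post_layers is a hypothetical stand-in for self.seq_layers and the shapes are made up:

import torch

batch, seq_len, hidden = 4, 8, 32
post_layers = torch.nn.Sequential(torch.nn.Linear(hidden, hidden), torch.nn.ReLU())

concat_out = torch.randn(batch, seq_len, hidden)
full_out = post_layers(concat_out.reshape([-1, hidden]))  # [batch * seq_len, hidden]
full_out = full_out.reshape([-1, seq_len, hidden])        # [batch, seq_len, hidden]
assert full_out.shape == (batch, seq_len, hidden)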
|
|
|