        # NOTE: the leading parameters are an assumption; they were lost from
        # this excerpt and are reconstructed from how __init__ uses them below.
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        batch_first: bool = True,
        forget_bias: float = 1.0,
        kernel_init: Initialization = Initialization.XavierGlorotUniform,
        bias_init: Initialization = Initialization.Zero,
        num_post_layers: int = 1,
    ):
        super().__init__()
        # Wrap the underlying LSTM; lstm_layer applies the chosen weight and
        # bias initializers and sets the forget-gate bias to forget_bias.
        self.lstm = lstm_layer(
            input_size,
            hidden_size,
            num_layers,
            batch_first,
            forget_bias,
            kernel_init,
            bias_init,
        )
        self.hidden_size = hidden_size
        self.layers = []
        # Build num_post_layers feed-forward layers to post-process the LSTM
        # output. Each one operates on the hidden_size-wide LSTM output, so
        # both its in-features and out-features are hidden_size.
        for _ in range(num_post_layers):
            self.layers.append(
                linear_layer(
                    hidden_size,
                    hidden_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                )
            )
            self.layers.append(Swish())
        self.seq_layers = torch.nn.Sequential(*self.layers)
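
    # A note on the design (an inference from the structure of forward below):
    # this module appears to implement the max-aggregation variant of AMRL
    # (Beck et al., "AMRL: Aggregated Memory For Reinforcement Learning").
    # Half of the LSTM output is folded through a running elementwise max,
    # which is order-invariant and cannot "forget" a large activation, while
    # the other half keeps the usual order-sensitive LSTM behavior.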

    def forward(self, input_tensor, h0_c0):
        hidden = h0_c0
        all_out = []
        m = None
        lstm_out, hidden = self.lstm(input_tensor, hidden)
        # Split the LSTM output into the half that gets max-aggregated and
        # the half that is passed through unchanged.
        h_half, other_half = torch.split(lstm_out, self.hidden_size // 2, dim=-1)
        # Iterate over timesteps; dim 1 is time since the layer is batch-first.
        for t in range(h_half.shape[1]):
            h_half_subt = h_half[:, t : t + 1, :]
            # Running elementwise max over all timesteps seen so far.
            m = h_half_subt if m is None else torch.max(m, h_half_subt)
            out = torch.cat([m, other_half[:, t : t + 1, :]], dim=-1)
            all_out.append(out)
        # Post-process the whole sequence with the feed-forward stack.
        full_out = self.seq_layers(torch.cat(all_out, dim=1))
        return full_out, hidden
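
# Minimal usage sketch, with assumptions marked: the enclosing class name is
# not visible in this excerpt (AMRLMax below is an assumed name), and the
# recurrent state is taken to be an (h0, c0) tuple shaped like torch.nn.LSTM's,
# since self.lstm wraps one.
if __name__ == "__main__":
    batch_size, seq_len, input_size, hidden_size, num_layers = 4, 10, 6, 8, 1
    net = AMRLMax(input_size, hidden_size)  # assumed class name
    h0 = torch.zeros(num_layers, batch_size, hidden_size)
    c0 = torch.zeros(num_layers, batch_size, hidden_size)
    x = torch.randn(batch_size, seq_len, input_size)  # batch-first layout
    out, (hn, cn) = net(x, (h0, c0))
    print(out.shape)  # expected: torch.Size([4, 10, 8])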