        # Post-processing layers applied to the aggregated LSTM output.
        for _ in range(num_post_layers):
            self.layers.append(
                linear_layer(
                    hidden_size,
                    hidden_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                )
            )
            self.layers.append(Swish())  # nonlinearity between the linear post-layers
        self.seq_layers = torch.nn.Sequential(*self.layers)

    def forward(self, input_tensor, h0_c0):
        """Run the LSTM over the full sequence, then replace the first half
        of each timestep's output with a running elementwise max over time."""
        hidden = h0_c0
        all_c = []
        m = None
        # Process the whole sequence in a single LSTM call.
        lstm_out, hidden = self.lstm(input_tensor, hidden)
        # Split each timestep's output in two: the first half is aggregated
        # with a running max, the second half passes through unchanged.
        h_half, other_half = torch.split(lstm_out, self.hidden_size // 2, dim=-1)
        for t in range(h_half.shape[1]):
            h_half_subt = h_half[:, t : t + 1, :]
            if m is None:
                m = h_half_subt
            else:
                # Elementwise max over all timesteps seen so far.
                m = torch.max(m, h_half_subt)
            all_c.append(m)
        concat_c = torch.cat(all_c, dim=1)
        concat_out = torch.cat([concat_c, other_half], dim=-1)
        # Flatten (batch, seq) so the post-layers see 2D input, then restore
        # the sequence dimension.
        full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size]))
        full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size])
        return full_out, hidden
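
# Minimal usage sketch, not part of the original file. It assumes the enclosing
# class is named AMRLMax and is constructed as AMRLMax(input_size, hidden_size);
# the class name and constructor signature are illustrative guesses, and
# hidden_size must be even for the half-split in forward() to be exact.
if __name__ == "__main__":
    import torch

    batch, seq_len, input_size, hidden_size = 4, 10, 64, 128
    amrl = AMRLMax(input_size, hidden_size)  # hypothetical constructor call
    seq = torch.randn(batch, seq_len, input_size)  # (batch, seq, features)
    h0_c0 = (
        torch.zeros(1, batch, hidden_size),  # h0: (num_layers, batch, hidden)
        torch.zeros(1, batch, hidden_size),  # c0
    )
    out, hidden = amrl(seq, h0_c0)
    print(out.shape)  # expected: torch.Size([4, 10, 128])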