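import torch

# NOTE (assumption): `_init_methods` is referenced below but not defined in
# this excerpt. This minimal stand-in maps initializer names to the in-place
# torch.nn.init functions they presumably stand for; the real table may
# contain different or additional entries.
_init_methods = {
    "orthogonal": torch.nn.init.orthogonal_,
    "xavier_uniform": torch.nn.init.xavier_uniform_,
    "zeros": torch.nn.init.zeros_,
}
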
# NOTE (assumption): the enclosing function's name and signature are missing
# from this excerpt; they are reconstructed from the names used in the body.
def make_lstm(input_size, hidden_size, num_layers, batch_first,
              kernel_init, bias_init, forget_bias):
    lstm = torch.nn.LSTM(input_size, hidden_size, num_layers,
                         batch_first=batch_first)
    for name, param in lstm.named_parameters():
        # Each weight and bias is a concatenation of 4 matrices, stacked in
        # PyTorch's gate order: input, forget, cell, output.
        if "weight" in name:
            # Initialize each gate's block separately, so structured schemes
            # such as orthogonal init apply per gate rather than to the
            # stacked 4 * hidden_size matrix.
            for idx in range(4):
                block_size = param.shape[0] // 4
                _init_methods[kernel_init](
                    param.data[idx * block_size : (idx + 1) * block_size])
        elif "bias" in name:
            for idx in range(4):
                block_size = param.shape[0] // 4
                _init_methods[bias_init](
                    param.data[idx * block_size : (idx + 1) * block_size])
                # Add forget_bias to the forget gate bias: block index 1,
                # i.e. param.data[hidden_size : 2 * hidden_size].
                if idx == 1:
                    param.data[idx * block_size :
                               (idx + 1) * block_size].add_(forget_bias)
    return lstm
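
# Usage sketch (illustrative): `make_lstm` and the initializer names are the
# reconstructions from above, not part of the original excerpt. With
# bias_init="zeros", the forget-gate block of every bias vector should equal
# forget_bias afterwards; since both bias_ih_* and bias_hh_* match "bias",
# the effective forget-gate bias inside the cell is 2 * forget_bias.
if __name__ == "__main__":
    net = make_lstm(input_size=8, hidden_size=16, num_layers=2,
                    batch_first=True, kernel_init="orthogonal",
                    bias_init="zeros", forget_bias=1.0)
    for name, param in net.named_parameters():
        if "bias" in name:
            h = param.shape[0] // 4
            assert torch.all(param[h : 2 * h] == 1.0)  # forget gate block
            assert torch.all(param[:h] == 0.0)         # input gate block
            assert torch.all(param[2 * h :] == 0.0)    # cell and output blocks
    out, (h_n, c_n) = net(torch.randn(4, 5, 8))  # (batch, time, features)
    print(out.shape)  # torch.Size([4, 5, 16])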