        # Run the LSTM over the input with the (h0, c0) hidden-state tuple.
        lstm_out, hidden_out = self.lstm(input_tensor, hidden)
        # Re-pack the (h_n, c_n) tuple into a single flat memory tensor so the
        # caller can carry the recurrent state forward as one tensor.
        output_mem = torch.cat(hidden_out, dim=-1)
        return lstm_out, output_mem
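def _memory_roundtrip_demo():
    # Illustrative sketch, not part of the original module: packs and unpacks
    # LSTM memories the way the forward pass above does. All sizes here are
    # assumed for demonstration; relies on this file's module-level `torch`
    # import.
    hidden_size, batch = 5, 2
    lstm = torch.nn.LSTM(3, hidden_size, batch_first=True)
    x = torch.randn(batch, 7, 3)
    h0 = torch.zeros(1, batch, hidden_size)
    c0 = torch.zeros(1, batch, hidden_size)
    _, (h_n, c_n) = lstm(x, (h0, c0))
    mem = torch.cat((h_n, c_n), dim=-1)  # packed memory, shape (1, batch, 2H)
    h, c = mem[..., :hidden_size], mem[..., hidden_size:]  # unpack the halves
    assert torch.equal(h, h_n) and torch.equal(c, c_n)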
|
|
class HyperNetwork(torch.nn.Module):
    def __init__(
        self, input_size, output_size, hyper_input_size, num_layers, layer_size
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # Trunk: an MLP that maps the hyper input to a hidden representation.
        layer_in_size = hyper_input_size
        layers = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size
        # Head: emits one flat vector per example, holding the generated
        # weight matrix (input_size * output_size entries) followed by the
        # bias (output_size entries). kernel_gain=0.1 scales down the head's
        # initial weights so the generated parameters start small.
        flat_output = linear_layer(
            layer_size,
            input_size * output_size + output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )
        self.hypernet = torch.nn.Sequential(*layers, flat_output)
|
    def forward(self, input_activation, hyper_input):
        # Generate one flat parameter vector per example from the hyper input.
        flat_output_weights = self.hypernet(hyper_input)
        batch_size = input_activation.size(0)

        # Split the flat vector into weights and bias: the first chunk has
        # input_size * output_size entries, the remainder is the bias.
        output_weights, output_bias = torch.split(
            flat_output_weights, self.input_size * self.output_size, dim=-1
        )

        output_weights = output_weights.view(
            batch_size, self.input_size, self.output_size
        )
        output_bias = output_bias.view(batch_size, self.output_size)
        # Apply the generated per-example linear layer with a batched matrix
        # multiply: (B, 1, in) x (B, in, out) -> (B, 1, out) -> (B, out).
        output = (
            torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
            + output_bias
        )
        return output
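

# Minimal usage sketch, not part of the original module: shows the expected
# tensor shapes when HyperNetwork generates a per-example linear layer. All
# sizes below are illustrative assumptions.
if __name__ == "__main__":
    batch, in_size, out_size, hyper_size = 4, 8, 3, 6
    net = HyperNetwork(
        input_size=in_size,
        output_size=out_size,
        hyper_input_size=hyper_size,
        num_layers=2,
        layer_size=16,
    )
    x = torch.randn(batch, in_size)        # activation to be transformed
    goal = torch.randn(batch, hyper_size)  # conditioning (hyper) input
    y = net(x, goal)
    assert y.shape == (batch, out_size)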