|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HyperNetwork(nn.Module): |
|
|
|
def __init__(self, input_size, output_size, hyper_input_size, num_layers, layer_size): |
|
|
|
def __init__( |
|
|
|
self, input_size, output_size, hyper_input_size, num_layers, layer_size |
|
|
|
): |
|
|
|
layers = [linear_layer( |
|
|
|
hyper_input_size, |
|
|
|
layer_size, |
|
|
|
kernel_init=Initialization.KaimingHeNormal, |
|
|
|
kernel_gain=1.0, |
|
|
|
bias_init=Initialization.Zero, |
|
|
|
), Swish()] |
|
|
|
layers = [ |
|
|
|
linear_layer( |
|
|
|
hyper_input_size, |
|
|
|
layer_size, |
|
|
|
kernel_init=Initialization.KaimingHeNormal, |
|
|
|
kernel_gain=1.0, |
|
|
|
bias_init=Initialization.Zero, |
|
|
|
), |
|
|
|
Swish(), |
|
|
|
] |
|
|
|
for _ in range(num_layers - 1): |
|
|
|
layers.append( |
|
|
|
linear_layer( |
|
|
|
|
|
|
batch_size = input_activation.size(0) |
|
|
|
|
|
|
|
output_weights, output_bias = torch.split( |
|
|
|
flat_output_weights, |
|
|
|
self.input_size * self.output_size, |
|
|
|
dim=-1, |
|
|
|
flat_output_weights, self.input_size * self.output_size, dim=-1 |
|
|
|
output_weights = output_weights.view(batch_size, self.input_size, self.output_size) |
|
|
|
output_weights = output_weights.view( |
|
|
|
batch_size, self.input_size, self.output_size |
|
|
|
) |
|
|
|
print(output_weights.shape, output_bias.shape, input_activation.shape) |
|
|
|
output = torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1) + output_bias |
|
|
|
print(output.shape) |
|
|
|
output = ( |
|
|
|
torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1) |
|
|
|
+ output_bias |
|
|
|
) |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.input_size = input_size |
|
|
|
self.output_size = output_size |
|
|
|
self.streams_size = len(stream_names) |
|
|
|
self.hypernetwork = HyperNetwork(input_size, self.output_size * self.streams_size, goal_size, num_layers, layer_size) |
|
|
|
self.hypernetwork = HyperNetwork( |
|
|
|
input_size, |
|
|
|
self.output_size * self.streams_size, |
|
|
|
goal_size, |
|
|
|
num_layers, |
|
|
|
layer_size, |
|
|
|
) |
|
|
|
|
|
|
|
def forward( |
|
|
|
self, hidden: torch.Tensor, goal: torch.Tensor |
|
|
|