|
|
|
|
|
|
from mlagents.torch_utils import torch |
|
|
|
import abc |
|
|
|
import math |
|
|
|
from typing import Tuple |
|
|
|
from enum import Enum |
|
|
|
|
|
|
|
|
|
|
layer_in_size = layer_size |
|
|
|
flat_output = linear_layer( |
|
|
|
layer_size, |
|
|
|
input_size * output_size + output_size, |
|
|
|
input_size * output_size, |
|
|
|
self.hypernet = torch.nn.Sequential(*layers, flat_output) |
|
|
|
|
|
|
|
# Re-initializing the weights of the last layer of the hypernetwork |
|
|
|
bound = math.sqrt(1 / (layer_size * self.input_size)) |
|
|
|
flat_output.weight.data.uniform_(-bound, bound) |
|
|
|
|
|
|
|
self.hypernet = torch.nn.Sequential(*layers, flat_output, LayerNorm()) |
|
|
|
|
|
|
|
# The hypernetwork will not generate the bias of the main network layer |
|
|
|
self.bias = torch.nn.Parameter(torch.zeros(output_size)) |
|
|
|
flat_output_weights = self.hypernet(hyper_input) |
|
|
|
batch_size = input_activation.size(0) |
|
|
|
output_weights = self.hypernet(hyper_input) |
|
|
|
output_weights, output_bias = torch.split( |
|
|
|
flat_output_weights, self.input_size * self.output_size, dim=-1 |
|
|
|
) |
|
|
|
|
|
|
|
output_weights = output_weights.view( |
|
|
|
batch_size, self.input_size, self.output_size |
|
|
|
) |
|
|
|
output_bias = output_bias.view(batch_size, self.output_size) |
|
|
|
output = ( |
|
|
|
output_weights = output_weights.view(-1, self.input_size, self.output_size) |
|
|
|
return ( |
|
|
|
+ output_bias |
|
|
|
+ self.bias |
|
|
|
return output |