        # Parameter list reconstructed from how the names are used in the body;
        # `hidden_size` and the exact ordering are assumptions.
        input_size: int,
        goal_size: int,
        num_layers: int,
        num_hyper_layers: int,
        hidden_size: int,
    ):
        super().__init__()
        # ModuleList (not a plain Python list) so the layers' parameters are
        # registered with the module and visible to optimizers.
        self.layers = torch.nn.ModuleList()
        self.num_hyper_layers = num_hyper_layers
        # Reconstructed branch: assume the goal is concatenated to the input
        # only when no hypernetwork layers are used.
        if num_hyper_layers > 0:
            prev_size = input_size
        else:
            prev_size = input_size + goal_size
        for i in range(num_layers):
            if i < num_layers - num_hyper_layers:
                # Ordinary fully connected layer for the early part of the stack.
                self.layers.append(torch.nn.Linear(prev_size, hidden_size))
            else:
                # Goal-conditioned hypernetwork layer; HyperNetwork is defined
                # elsewhere, and this constructor signature is an assumption.
                self.layers.append(HyperNetwork(prev_size, hidden_size, goal_size))
            prev_size = hidden_size

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:
        # Reconstructed branch, mirroring the constructor: concatenate the goal
        # only when no hypernetwork layers are used.
        if self.num_hyper_layers > 0:
            activation = input_tensor
        else:
            activation = torch.cat([input_tensor, goal_tensor], dim=-1)
        for layer in self.layers:
            if isinstance(layer, HyperNetwork):
                # Hypernetwork layers take the goal as an additional input.
                activation = layer(activation, goal_tensor)
            else:
                activation = layer(activation)
        return activation
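
# Minimal usage sketch (hypothetical: the class name `GoalConditionedNetwork`,
# the sizes, and the batch shape below are illustrative assumptions, and
# HyperNetwork must be defined before any hypernetwork layers are built):
# model = GoalConditionedNetwork(input_size=32, goal_size=8, num_layers=4,
#                                num_hyper_layers=2, hidden_size=256)
# output = model(torch.randn(16, 32), torch.randn(16, 8))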