|
|
|
|
|
|
super().__init__() |
|
|
|
self.layers = [] |
|
|
|
self.goal_encoders = [] |
|
|
|
prev_size = input_size |
|
|
|
for i in range(num_layers): |
|
|
|
prev_size = input_size + goal_size |
|
|
|
for _ in range(num_layers): |
|
|
|
self.layers.append( |
|
|
|
linear_layer( |
|
|
|
prev_size, |
|
|
|
|
|
|
) |
|
|
|
) |
|
|
|
self.goal_encoders.append(LinearEncoder(goal_size, 2, hidden_size, final_activation=False)) |
|
|
|
self.goal_encoders.append( |
|
|
|
LinearEncoder(goal_size, 2, hidden_size, final_activation=True) |
|
|
|
) |
|
|
|
self.layers = torch.nn.ModuleList(self.layers) |
|
|
|
self.goal_encoders = torch.nn.ModuleList(self.goal_encoders) |
|
|
|
activation = input_tensor |
|
|
|
activation = torch.cat([input_tensor, goal_tensor], dim=-1) |
|
|
|
activation = layer(activation) * self.goal_encoders[idx//2](goal_tensor) |
|
|
|
activation = layer(activation) + self.goal_encoders[idx // 2]( |
|
|
|
goal_tensor |
|
|
|
) |
|
|
|
return activation |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
HyperNetwork(prev_size, hidden_size, goal_size, 2, hidden_size) |
|
|
|
) |
|
|
|
self.layers.append(Swish()) |
|
|
|
self.layers = torch.nn.ModuleList(self.layers) |
|
|
|
prev_size = hidden_size |
|
|
|
|
|
|
|
def forward( |
|
|
|