        kernel_gain: float = 1.0,
    ):
        super().__init__()
        # ModuleList rather than plain lists so the sub-layers register their parameters.
        self.layers = torch.nn.ModuleList()
        self.goal_encoders = torch.nn.ModuleList()
        prev_size = input_size
        for _ in range(num_layers):
            self.layers.append(
                linear_layer(
                    prev_size,
                    hidden_size,
                    kernel_init=kernel_init,
                    kernel_gain=kernel_gain,
                )
            )
            self.layers.append(Swish())
            # One small goal encoder per linear layer; its output gates that layer's activations.
            self.goal_encoders.append(
                LinearEncoder(goal_size, 2, hidden_size, final_activation=False)
            )
            prev_size = hidden_size

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:
        activation = input_tensor
        for idx, layer in enumerate(self.layers):
            if isinstance(layer, Swish):
                activation = layer(activation)
            else:
                # Linear layers sit at even indices; gate their output with the
                # corresponding goal encoding before the next Swish.
                activation = layer(activation) * self.goal_encoders[idx // 2](goal_tensor)
        return activation
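
# The loop above is multiplicative (gating) conditioning: each hidden layer
# computes, in effect, Swish(linear(h) * goal_encoder_i(goal)). A sketch of the
# pattern with illustrative names (not part of this module):
#
#     h = obs
#     for linear, gate in zip(linears, goal_encoders):
#         h = swish(linear(h) * gate(goal))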


class HyperEncoder(torch.nn.Module):
    """
    Linear layers, where the last num_hyper_layers layers are HyperNetwork layers
    whose weights are generated from a goal tensor.
    """

    def __init__(
        self,
        input_size: int,
        goal_size: int,
        num_layers: int,
        hidden_size: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
        num_hyper_layers: int = 1,
    ):
        super().__init__()
        # ModuleList so both the linear and HyperNetwork sub-layers register their parameters.
        self.layers = torch.nn.ModuleList()
        prev_size = input_size
        for i in range(num_layers):
            if i < num_layers - num_hyper_layers:
                self.layers.append(
                    linear_layer(
                        prev_size,
                        hidden_size,
                        kernel_init=kernel_init,
                        kernel_gain=kernel_gain,
                    )
                )
            else:
                # Final num_hyper_layers layers: weights produced by a HyperNetwork from the goal.
                self.layers.append(
                    HyperNetwork(prev_size, hidden_size, goal_size, 2, hidden_size)
                )
            self.layers.append(Swish())
            prev_size = hidden_size

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:
        activation = input_tensor
        for layer in self.layers:
            if isinstance(layer, HyperNetwork):
                # HyperNetwork layers need the goal to generate their weights.
                activation = layer(activation, goal_tensor)
            else:
                activation = layer(activation)
        return activation
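
# Illustrative usage sketch (sizes are assumptions, not from this module): an
# observation of size 64 and a goal of size 8, batch of 32. With the default
# num_hyper_layers=1, only the last layer's weights come from the HyperNetwork.
#
#     encoder = HyperEncoder(input_size=64, goal_size=8, num_layers=3, hidden_size=128)
#     hidden = encoder(torch.randn(32, 64), torch.randn(32, 8))  # -> (32, 128)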


class LinearEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        num_layers: int,
        hidden_size: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
        final_activation: bool = True,
    ):
        super().__init__()
        self.layers = [
            linear_layer(
                input_size,
                hidden_size,
                kernel_init=kernel_init,
                kernel_gain=kernel_gain,
            )
        ]
        self.layers.append(Swish())
        for i in range(num_layers - 1):
            self.layers.append(
                linear_layer(
                    hidden_size,
                    hidden_size,
                    kernel_init=kernel_init,
                    kernel_gain=kernel_gain,
                )
            )
            # Omit the trailing Swish on the last layer when final_activation is False.
            if i < num_layers - 2 or final_activation:
                self.layers.append(Swish())
        self.seq_layers = torch.nn.Sequential(*self.layers)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.seq_layers(input_tensor)
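
# Illustrative usage sketch (sizes are assumptions): with final_activation=False
# the stack ends in a linear layer instead of a Swish, as the goal encoders above use.
#
#     goal_encoder = LinearEncoder(8, 2, 128, final_activation=False)
#     encoding = goal_encoder(torch.randn(32, 8))  # -> (32, 128)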