|
|
|
|
|
|
super().__init__() |
|
|
|
n_channels = [16, 32, 32] # channel for each stack |
|
|
|
n_blocks = 2 # number of residual blocks |
|
|
|
self.layers = [] |
|
|
|
layers = [] |
|
|
|
self.layers.append( |
|
|
|
nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1) |
|
|
|
) |
|
|
|
self.layers.append(nn.MaxPool2d([3, 3], [2, 2])) |
|
|
|
layers.append(nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1)) |
|
|
|
layers.append(nn.MaxPool2d([3, 3], [2, 2])) |
|
|
|
self.layers.append(ResNetBlock(channel)) |
|
|
|
layers.append(ResNetBlock(channel)) |
|
|
|
self.layers.append(Swish()) |
|
|
|
layers.append(Swish()) |
|
|
|
self.dense = linear_layer( |
|
|
|
n_channels[-1] * height * width, |
|
|
|
final_hidden, |
|
|
|
|
|
|
self.sequential = nn.Sequential(*self.layers) |
|
|
|
hidden = visual_obs |
|
|
|
for layer in self.layers: |
|
|
|
hidden = layer(hidden) |
|
|
|
hidden = self.sequential(visual_obs) |
|
|
|
before_out = hidden.view(batch_size, -1) |
|
|
|
return torch.relu(self.dense(before_out)) |