|
|
|
|
|
|
layers.append(ResNetBlock(channel)) |
|
|
|
last_channel = channel |
|
|
|
layers.append(Swish()) |
|
|
|
self.final_flat_size = n_channels[-1] * height * width |
|
|
|
n_channels[-1] * height * width, |
|
|
|
self.final_flat_size, |
|
|
|
output_size, |
|
|
|
kernel_init=Initialization.KaimingHeNormal, |
|
|
|
kernel_gain=1.41, # Use ReLU gain |
|
|
|
|
|
|
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
    """Encode a batch of visual observations into a flat feature vector.

    :param visual_obs: Image batch. NHWC when coming from the trainer;
        already NCHW when exporting to ONNX (the exporter feeds
        channels-first input, so the permute is skipped).
    :return: ReLU-activated output of the final dense layer.
    """
    if not exporting_to_onnx.is_exporting():
        # Trainer supplies NHWC images; the conv stack expects NCHW.
        visual_obs = visual_obs.permute([0, 3, 1, 2])
    # Run the convolutional ResNet stack.
    # NOTE(review): this line is reconstructed — the statement producing
    # `hidden` is elided in this chunk; confirm the attribute name holding
    # the conv layers.
    hidden = self.sequential(visual_obs)
    # Flatten with a fixed feature size and a dynamic batch dimension.
    # Using self.final_flat_size (rather than reshape(batch_size, -1))
    # keeps the ONNX export free of a data-dependent batch-size node.
    # The previous duplicate assignment `before_out = hidden.reshape(
    # batch_size, -1)` was a dead store (immediately overwritten) and has
    # been removed along with the now-unused `batch_size` local.
    before_out = hidden.reshape(-1, self.final_flat_size)
    return torch.relu(self.dense(before_out))