|
|
|
|
|
|
if not exporting_to_onnx.is_exporting(): |
|
|
|
visual_obs = visual_obs.permute([0, 3, 1, 2]) |
|
|
|
hidden = self.conv_layers(visual_obs) |
|
|
|
hidden = hidden.view([-1, self.final_flat]) |
|
|
|
hidden = hidden.reshape([-1, self.final_flat]) |
|
|
|
return self.dense(hidden) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not exporting_to_onnx.is_exporting(): |
|
|
|
visual_obs = visual_obs.permute([0, 3, 1, 2]) |
|
|
|
batch_size = visual_obs.shape[0] |
|
|
|
hidden = self.sequential(visual_obs).contiguous() |
|
|
|
before_out = hidden.view(batch_size, -1) |
|
|
|
hidden = self.sequential(visual_obs) |
|
|
|
before_out = hidden.reshape(batch_size, -1) |
|
|
|
return torch.relu(self.dense(before_out)) |