kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.0,
)
self.sequential = nn.Sequential(*self.layers)
self.sequential = nn.Sequential(*layers)
def forward(self, visual_obs):
batch_size = visual_obs.shape[0]