|
|
|
|
|
|
|
|
|
|
visual_encoder = ModelUtils.get_encoder_for_type(vis_encode_type) |
|
|
|
for vector_size in vector_sizes: |
|
|
|
self.vector_normalizers.append(Normalizer(vector_size)) |
|
|
|
self.vector_encoders.append(VectorEncoder(vector_size, h_size, num_layers)) |
|
|
|
if vector_size != 0: |
|
|
|
self.vector_normalizers.append(Normalizer(vector_size)) |
|
|
|
self.vector_encoders.append( |
|
|
|
VectorEncoder(vector_size, h_size, num_layers) |
|
|
|
) |
|
|
|
self.visual_encoders.append(visual_encoder(visual_size)) |
|
|
|
self.visual_encoders.append( |
|
|
|
visual_encoder(visual_size.num_channels, h_size) |
|
|
|
) |
|
|
|
|
|
|
|
self.vector_encoders = nn.ModuleList(self.vector_encoders) |
|
|
|
self.visual_encoders = nn.ModuleList(self.visual_encoders) |
|
|
|
|
|
|
|
|
|
|
vis_embeds = [] |
|
|
|
for idx, encoder in enumerate(self.visual_encoders): |
|
|
|
hidden = encoder(vis_inputs[idx]) |
|
|
|
vis_input = vis_inputs[idx] |
|
|
|
vis_input = vis_input.permute([0, 3, 1, 2]) |
|
|
|
hidden = encoder(vis_input) |
|
|
|
vis_embeds.append(hidden) |
|
|
|
|
|
|
|
if len(vec_embeds) > 0: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimpleVisualEncoder(nn.Module): |
|
|
|
def __init__(self, initial_channels): |
|
|
|
def __init__(self, initial_channels, output_size): |
|
|
|
self.h_size = output_size |
|
|
|
self.dense = nn.Linear(1728, self.h_size) |
|
|
|
return torch.flatten(conv_2) |
|
|
|
hidden = self.dense(conv_2.reshape([-1, 1728])) |
|
|
|
return hidden |
|
|
|
|
|
|
|
|
|
|
|
class NatureVisualEncoder(nn.Module): |
|
|
|