|
|
|
|
|
|
torch.zeros(1, batch_size, self.m_size), |
|
|
|
) |
|
|
|
|
|
|
|
def update_normalization(self, vec_inputs):
    """Update the running normalization statistics for each vector
    observation stream.

    :param vec_inputs: one batch per vector observation stream; entry
        ``idx`` is fed to ``self.vector_normalizers[idx].update``.
        An empty sequence is a no-op.

    NOTE(review): the recovered source contained two conflicting
    definitions of this method — an older one calling
    ``self.normalizer.update(inputs)`` and this per-stream version.
    The per-stream version is kept because ``self.vector_normalizers``
    is also what ``forward`` uses; confirm against version control.
    """
    for idx, vec_input in enumerate(vec_inputs):
        self.vector_normalizers[idx].update(vec_input)
|
|
|
|
|
|
|
def forward(self, vec_inputs, vis_inputs):
    """Encode vector and visual observations into a single embedding.

    :param vec_inputs: one tensor batch per vector observation stream.
    :param vis_inputs: one tensor batch per visual observation stream.
    :return: the concatenated embedding (optionally passed through the
        LSTM when ``self.use_lstm`` is set, which also updates
        ``self.memory`` as a side effect).

    NOTE(review): the recovered source was missing the encoder-loop
    headers; they are reconstructed below from the loop bodies
    (``enumerate`` index usage, ``encoder(...)`` calls, ``.append``).
    The attribute names ``self.vector_encoders`` / ``self.visual_encoders``
    are assumptions — confirm against the class ``__init__``.
    It also contained two conflicting normalization lines
    (``self.normalizers[idx]`` vs ``self.vector_normalizers[idx]``) and
    an unguarded duplicate of the concatenation logic; the
    ``vector_normalizers`` call and the guarded concatenation are kept,
    since ``torch.cat([])`` raises on an empty embed list.
    """
    vec_embeds = []
    for idx, encoder in enumerate(self.vector_encoders):
        # Normalize each vector stream before encoding it.
        vec_input = self.vector_normalizers[idx](vec_inputs[idx])
        hidden = encoder(vec_input)
        vec_embeds.append(hidden)

    vis_embeds = []
    for idx, encoder in enumerate(self.visual_encoders):
        hidden = encoder(vis_inputs[idx])
        vis_embeds.append(hidden)

    # Guarded concatenation: either stream may be absent, and
    # torch.cat on an empty list raises RuntimeError.
    if len(vec_embeds) > 0:
        vec_embeds = torch.cat(vec_embeds)
    if len(vis_embeds) > 0:
        vis_embeds = torch.cat(vis_embeds)
    if len(vec_embeds) > 0 and len(vis_embeds) > 0:
        embedding = torch.cat([vec_embeds, vis_embeds])
    elif len(vec_embeds) > 0:
        embedding = vec_embeds
    else:
        embedding = vis_embeds

    if self.use_lstm:
        # Recurrent pass; self.memory carries the (h, c) state across calls.
        embedding, self.memory = self.lstm(embedding, self.memory)
    return embedding
|
|
|
|
|
|
value_outputs = {} |
|
|
|
for stream_name, _ in self.value_heads.items(): |
|
|
|
value_outputs[stream_name] = self.value_heads[stream_name](hidden) |
|
|
|
return value_outputs, torch.mean(torch.stack(list(value_outputs)), dim=0) |
|
|
|
return ( |
|
|
|
value_outputs, |
|
|
|
torch.mean(torch.stack(list(value_outputs.values())), dim=0), |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class VectorEncoder(nn.Module): |
|
|
|