# Resize to (batch, sequence length, encoding size)
encoding = encoding.reshape([-1, sequence_length, self.h_size])

# Split the flat memory tensor into the LSTM's hidden and cell states,
# each of size m_size // 2
memories = torch.split(memories, self.m_size // 2, dim=-1)

# Run the sequence through the LSTM with the recovered state
encoding, memories = self.lstm(encoding, (memories[0], memories[1]))

# Flatten the output back to (batch * sequence length, hidden size)
encoding = encoding.reshape([-1, self.m_size // 2])

# Re-pack the hidden and cell states into a single flat memory tensor
memories = torch.cat(memories, dim=-1)

return encoding, memories
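# The code above assumes a single flat "memories" tensor that packs the LSTM
# hidden and cell states along the last dimension. Below is a minimal,
# standalone sketch of that split / run / re-pack pattern; the sizes, the
# (1, batch, m_size) memory layout, and the batch_first nn.LSTM are
# illustrative assumptions, not taken from the surrounding module.

import torch
from torch import nn

batch, sequence_length, h_size, m_size = 4, 8, 32, 64

# LSTM whose hidden size is half the flat memory width
lstm = nn.LSTM(h_size, m_size // 2, batch_first=True)

# Flat inputs as the rest of the network would provide them:
# one encoding row per step and one packed memory vector per sequence
encoding = torch.randn(batch * sequence_length, h_size)
memories = torch.zeros(1, batch, m_size)

# Same resize / split / run / re-pack sequence as above
encoding = encoding.reshape([-1, sequence_length, h_size])
hidden, cell = torch.split(memories, m_size // 2, dim=-1)
encoding, (hidden, cell) = lstm(
    encoding, (hidden.contiguous(), cell.contiguous())
)
encoding = encoding.reshape([-1, m_size // 2])
memories = torch.cat((hidden, cell), dim=-1)

print(encoding.shape)  # torch.Size([32, 32])
print(memories.shape)  # torch.Size([1, 4, 64])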