|
|
|
|
|
|
self, input_tensor: torch.Tensor, memories: torch.Tensor |
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
# We don't use torch.split here since it is not supported by Barracuda |
|
|
|
h0 = memories[:, :, : self.hidden_size] |
|
|
|
c0 = memories[:, :, self.hidden_size :] |
|
|
|
h0 = memories[:, :, : self.hidden_size].contiguous() |
|
|
|
c0 = memories[:, :, self.hidden_size :].contiguous() |
|
|
|
hidden = (h0, c0) |
|
|
|
lstm_out, hidden_out = self.lstm(input_tensor, hidden) |
|
|
|
output_mem = torch.cat(hidden_out, dim=-1) |