fut = torch.jit._fork(
self.network_body, vec_inputs, vis_inputs, memories, sequence_length
)
embedding, memories = torch.jit._wait(fut)
return embedding, value_outputs, memories
@torch.jit.ignore