
Make memory contiguous (#4804)

/MLA-1734-demo-provider
GitHub · 4 years ago
Commit e344fe79
1 changed file with 2 additions and 2 deletions
ml-agents/mlagents/trainers/torch/layers.py (+2 −2)

     self, input_tensor: torch.Tensor, memories: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # We don't use torch.split here since it is not supported by Barracuda
-    h0 = memories[:, :, : self.hidden_size]
-    c0 = memories[:, :, self.hidden_size :]
+    h0 = memories[:, :, : self.hidden_size].contiguous()
+    c0 = memories[:, :, self.hidden_size :].contiguous()
     hidden = (h0, c0)
     lstm_out, hidden_out = self.lstm(input_tensor, hidden)
     output_mem = torch.cat(hidden_out, dim=-1)
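Why the change helps: slicing the last axis of a tensor produces a strided view whose elements are no longer laid out consecutively in memory, and some consumers (such as cuDNN's LSTM kernels) require contiguous inputs. A minimal sketch of the effect, using NumPy rather than PyTorch purely to illustrate the memory-layout concept (the shapes and `hidden_size` below are made up for the example; `np.ascontiguousarray` plays the role of `tensor.contiguous()`):

```python
import numpy as np

hidden_size = 4
# Fake "memories" array shaped (batch, seq, 2 * hidden_size),
# standing in for the torch.Tensor in the diff above.
memories = np.arange(2 * 3 * 2 * hidden_size, dtype=np.float32).reshape(2, 3, 2 * hidden_size)

# Slicing the last axis yields views that skip over half of each row,
# so neither half is C-contiguous in memory.
h0 = memories[:, :, :hidden_size]
c0 = memories[:, :, hidden_size:]
print(h0.flags["C_CONTIGUOUS"])  # False
print(c0.flags["C_CONTIGUOUS"])  # False

# Copying into a fresh buffer restores contiguity without changing values.
h0 = np.ascontiguousarray(h0)
c0 = np.ascontiguousarray(c0)
print(h0.flags["C_CONTIGUOUS"])  # True
print(np.array_equal(c0, memories[:, :, hidden_size:]))  # True
```

In the commit itself the copy is done with `.contiguous()` on each slice, which is a no-op when the tensor is already contiguous and otherwise materializes a densely packed copy before the `(h0, c0)` pair is handed to `self.lstm`.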