
Don't use torch.split in LSTM

/develop/add-fire/export-discrete
Ervin Teng, 4 years ago
Commit aeda0b32
1 changed file with 3 additions and 1 deletion
ml-agents/mlagents/trainers/torch/layers.py (4 changed lines)


    def forward(
        self, input_tensor: torch.Tensor, memories: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
-       h0, c0 = torch.split(memories, self.hidden_size, dim=-1)
+       # We don't use torch.split here since it is not supported by Barracuda
+       h0 = memories[:, :, : self.hidden_size]
+       c0 = memories[:, :, self.hidden_size :]
        hidden = (h0, c0)
        lstm_out, hidden_out = self.lstm(input_tensor, hidden)
        output_mem = torch.cat(hidden_out, dim=-1)
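
For context, here is a minimal standalone sketch (not part of the commit; the tensor shapes and sizes are illustrative assumptions) checking that the slicing replacement produces the same h0/c0 tensors as the removed torch.split call:

import torch

# Assumed memory layout: (num_layers, batch, 2 * hidden_size),
# with h0 and c0 concatenated along the last dimension.
hidden_size = 8
memories = torch.randn(1, 4, 2 * hidden_size)

# Original approach, removed because Barracuda does not support it:
h0_split, c0_split = torch.split(memories, hidden_size, dim=-1)

# Replacement from the commit: plain slicing on the last dimension.
h0 = memories[:, :, :hidden_size]
c0 = memories[:, :, hidden_size:]

# Both paths yield identical hidden-state and cell-state tensors.
assert torch.equal(h0, h0_split) and torch.equal(c0, c0_split)

Both forms split the concatenated memory tensor into the LSTM hidden state and cell state; slicing simply expresses the same operation without torch.split, which the commit notes Barracuda cannot import.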