
Fix torch tensor non-contiguous issue (#4855)

* add contiguous

* preserve tensor device type
Branch: /MLA-1734-demo-provider
GitHub · 4 years ago
Commit e9ff7705
Showing 3 changed files with 5 additions and 5 deletions.
  1. ml-agents/mlagents/trainers/torch/attention.py (3 changes)
  2. ml-agents/mlagents/trainers/torch/encoders.py (6 changes)
  3. ml-agents/mlagents/trainers/torch/networks.py (1 change)

ml-agents/mlagents/trainers/torch/attention.py (3 changes)

     with torch.no_grad():
         # Generate the masking tensors for each entities tensor (mask only if all zeros)
         key_masks: List[torch.Tensor] = [
-            (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
-            for ent in observations
+            (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
         ]
         return key_masks
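
This one-line change matters for GPU training: Tensor.type(torch.FloatTensor) always returns a CPU tensor, so a mask built from a CUDA observation would silently land on the wrong device, whereas .float() only casts the dtype and keeps the tensor on the observation's device. A minimal sketch (tensor name and shape are illustrative, not taken from ml-agents):

import torch

# Illustrative entity tensor; shape (batch, entities, features) is assumed.
device = "cuda" if torch.cuda.is_available() else "cpu"
ent = torch.randn(4, 8, 3, device=device)

mask_old = (torch.sum(ent ** 2, dim=2) < 0.01).type(torch.FloatTensor)
mask_new = (torch.sum(ent ** 2, dim=2) < 0.01).float()

print(mask_old.device)  # always cpu: torch.FloatTensor implies a CPU tensor
print(mask_new.device)  # matches ent's device: .float() only changes dtype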

ml-agents/mlagents/trainers/torch/encoders.py (6 changes)

     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     hidden = self.conv_layers(visual_obs)
-    hidden = hidden.view([-1, self.final_flat])
+    hidden = hidden.reshape([-1, self.final_flat])
     return self.dense(hidden)

     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     batch_size = visual_obs.shape[0]
-    hidden = self.sequential(visual_obs).contiguous()
-    before_out = hidden.view(batch_size, -1)
+    hidden = self.sequential(visual_obs)
+    before_out = hidden.reshape(batch_size, -1)
     return torch.relu(self.dense(before_out))
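
Both encoders.py hunks swap view() for reshape() for the same reason: permute() returns a non-contiguous view, and view() requires memory laid out compatibly with the requested shape (which is what the old explicit .contiguous() call worked around), while reshape() returns a view when it can and copies only when it must. A minimal sketch with an assumed NHWC observation shape:

import torch

# Assumed visual-observation shape (batch, H, W, C); not taken from ml-agents.
visual_obs = torch.randn(2, 84, 84, 3)
nchw = visual_obs.permute([0, 3, 1, 2])  # channel-first view, non-contiguous

print(nchw.is_contiguous())  # False
flat = nchw.reshape(2, -1)   # works: copies the data when necessary
try:
    nchw.view(2, -1)         # raises: view() needs contiguous-compatible strides
except RuntimeError as exc:
    print("view() failed:", exc)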

ml-agents/mlagents/trainers/torch/networks.py (1 change)

     if self.use_lstm and memories is not None:
         # Use only the back half of memories for critic and actor
         actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
+        actor_mem, critic_mem = actor_mem.contiguous(), critic_mem.contiguous()
     else:
         critic_mem = None
         actor_mem = None
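
The added line guards the LSTM path: torch.split() returns views that share the parent tensor's strides, so splitting on the last dimension yields non-contiguous halves, which PyTorch RNNs reject when passed as hidden state (an "hx is not contiguous" RuntimeError). A minimal sketch with hypothetical sizes:

import torch

# Hypothetical memory layout (num_layers, batch, 2 * half); not from ml-agents.
memories = torch.randn(1, 4, 256)
actor_mem, critic_mem = torch.split(memories, 128, dim=-1)

print(actor_mem.is_contiguous(), critic_mem.is_contiguous())  # False False
actor_mem, critic_mem = actor_mem.contiguous(), critic_mem.contiguous()
print(actor_mem.is_contiguous(), critic_mem.is_contiguous())  # True True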
