
Silencing a Torch warning that raises when exporting the attention module to ONNX

/bullet-hell-barracuda-test-1.3.1
vincentpierre, 4 years ago
Commit bd6ab0f7
1 file changed, 17 insertions(+), 9 deletions(-)
1. ml-agents/mlagents/trainers/torch/attention.py (26 changed lines)


 from mlagents.torch_utils import torch
+import warnings
 from typing import Tuple, Optional, List
 from mlagents.trainers.torch.layers import (
     LinearEncoder,
...
         with torch.no_grad():
             if exporting_to_onnx.is_exporting():
-                # When exporting to ONNX, we want to transpose the entities. This is
-                # because ONNX only support input in NCHW (channel first) format.
-                # Barracuda also expect to get data in NCHW.
-                entities = [
-                    torch.transpose(obs, 2, 1).reshape(
-                        -1, int(obs.shape[1]), int(obs.shape[2])
-                    )
-                    for obs in entities
-                ]
+                with warnings.catch_warnings():
+                    # We ignore a TracerWarning from PyTorch that warns that doing
+                    # shape[n].item() will cause the trace to be incorrect (the trace might
+                    # not generalize to other inputs)
+                    # We ignore this warning because we know the model will always be
+                    # run with inputs of the same shape
+                    warnings.simplefilter("ignore")
+                    # When exporting to ONNX, we want to transpose the entities. This is
+                    # because ONNX only support input in NCHW (channel first) format.
+                    # Barracuda also expect to get data in NCHW.
+                    entities = [
+                        torch.transpose(obs, 2, 1).reshape(
+                            -1, obs.shape[1].item(), obs.shape[2].item()
+                        )
+                        for obs in entities
+                    ]
         # Generate the masking tensors for each entities tensor (mask only if all zeros)
         key_masks: List[torch.Tensor] = [
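For context, the suppression pattern the commit introduces can be sketched in isolation: `warnings.catch_warnings()` saves the warning-filter state and restores it on exit, so the `simplefilter("ignore")` only silences warnings raised inside the block. The sketch below uses a stand-in function in place of the PyTorch tracer (the names `traced_shape_read` and `export_step` are illustrative, not from the repo).

```python
import warnings

def traced_shape_read():
    # Stand-in for the PyTorch tracer: it warns when a tensor shape
    # is converted to a Python number during tracing.
    warnings.warn("Converting a tensor to a Python number", UserWarning)
    return 3

def export_step():
    with warnings.catch_warnings():
        # Scoped suppression: warnings are ignored only inside this
        # block; the previous filters are restored on exit.
        warnings.simplefilter("ignore")
        return traced_shape_read()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    n = export_step()

print(n)            # 3
print(len(caught))  # 0: the warning never escaped the inner block
```

Scoping the filter this way avoids globally hiding TracerWarnings, which would mask genuinely useful warnings elsewhere in the export path.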
