浏览代码

Develop fix visual torchh export (#4494) (#4497)

* Fixing exporting of ONNX for visual when using threading

* docstring was wrong
/release_7_branch
GitHub 4 年前
当前提交
37645aa2
共有 2 个文件被更改,包括 39 次插入10 次删除
  1. 46
      ml-agents/mlagents/trainers/torch/model_serialization.py
  2. 3
      ml-agents/mlagents/trainers/torch/networks.py

46
ml-agents/mlagents/trainers/torch/model_serialization.py


import os
import threading
from mlagents.torch_utils import torch
from mlagents_envs.logging_util import get_logger

logger = get_logger(__name__)
class exporting_to_onnx:
"""
Set this context by calling
```
with exporting_to_onnx():
```
Within this context, the variable exporting_to_onnx.is_exporting() will be true.
This implementation is thread safe.
"""
_local_data = threading.local()
_local_data._is_exporting = False
def __enter__(self):
self._local_data._is_exporting = True
def __exit__(self, *args):
self._local_data._is_exporting = False
@staticmethod
def is_exporting():
if not hasattr(exporting_to_onnx._local_data, "_is_exporting"):
return False
return exporting_to_onnx._local_data._is_exporting
class ModelSerializer:

onnx_output_path = f"{output_filepath}.onnx"
logger.info(f"Converting to {onnx_output_path}")
torch.onnx.export(
self.policy.actor_critic,
self.dummy_input,
onnx_output_path,
opset_version=SerializationSettings.onnx_opset,
input_names=self.input_names,
output_names=self.output_names,
dynamic_axes=self.dynamic_axes,
)
with exporting_to_onnx():
torch.onnx.export(
self.policy.actor_critic,
self.dummy_input,
onnx_output_path,
opset_version=SerializationSettings.onnx_opset,
input_names=self.input_names,
output_names=self.output_names,
dynamic_axes=self.dynamic_axes,
)
logger.info(f"Exported {onnx_output_path}")

3
ml-agents/mlagents/trainers/torch/networks.py


from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[

for idx, processor in enumerate(self.visual_processors):
vis_input = vis_inputs[idx]
if not torch.onnx.is_in_onnx_export():
if not exporting_to_onnx.is_exporting():
vis_input = vis_input.permute([0, 3, 1, 2])
processed_vis = processor(vis_input)
encodes.append(processed_vis)

正在加载...
取消
保存