from mlagents.torch_utils import torch
import warnings
from typing import Tuple, Optional, List
from mlagents.trainers.torch.layers import (
    LinearEncoder,

with torch.no_grad():

    if exporting_to_onnx.is_exporting():
        with warnings.catch_warnings():
            # We ignore a TracerWarning from PyTorch that warns that doing
            # shape[n].item() will cause the trace to be incorrect (the trace
            # might not generalize to other inputs). We ignore this warning
            # because we know the model will always be run with inputs of the
            # same shape.
            warnings.simplefilter("ignore")
            # When exporting to ONNX, we want to transpose the entities. This is
            # because ONNX only supports input in NCHW (channel-first) format.
            # Barracuda also expects to get data in NCHW.
            entities = [
                torch.transpose(obs, 2, 1).reshape(
                    -1, obs.shape[1].item(), obs.shape[2].item()
                )
                for obs in entities
            ]
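            # Illustrative sketch (hypothetical shapes, not from the original
            # code): if an entity observation obs arrives with shape
            # (batch_size, num_entities, entity_size), e.g. (1, 3, 4),
            # torch.transpose(obs, 2, 1) yields a (1, 4, 3) view and the
            # reshape lays the values back out as (1, 3, 4), so the rest of
            # the graph still sees (batch_size, num_entities, entity_size)
            # tensors while the exported model accepts the channel-first data
            # that Barracuda feeds it.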

    # Generate the masking tensors for each entities tensor (mask only if all zeros)
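    # (Hypothetical example, not from the original code: for an entities
    # tensor of shape (1, 3, 4) whose last entity row is all zeros, the
    # corresponding mask entry would be 1.0 and the other entries 0.0, so a
    # downstream attention layer can presumably ignore that padded entity.)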
    key_masks: List[torch.Tensor] = [