|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
inputs: List[torch.Tensor], |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
actions: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
|
|
|
net_input = inputs[idx] |
|
|
|
net_input = net_inputs[idx] |
|
|
|
if not exporting_to_onnx.is_exporting() and len(net_input.shape) > 3: |
|
|
|
net_input = net_input.permute([0, 3, 1, 2]) |
|
|
|
processed_vec = processor(net_input) |
|
|
|