您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
41 行
1.3 KiB
41 行
1.3 KiB
import pytest
|
|
from unittest import mock
|
|
|
|
import torch # noqa I201
|
|
|
|
from mlagents.torch_utils import set_torch_config, default_device
|
|
from mlagents.trainers.settings import TorchSettings
|
|
|
|
|
|
@pytest.mark.parametrize(
    "device_str, expected_type, expected_index, expected_tensor_type",
    [
        ("cpu", "cpu", None, torch.FloatTensor),
        ("cuda", "cuda", None, torch.cuda.FloatTensor),
        ("cuda:42", "cuda", 42, torch.cuda.FloatTensor),
        ("opengl", "opengl", None, torch.FloatTensor),
    ],
)
@mock.patch.object(torch, "set_default_tensor_type")
def test_set_torch_device(
    mock_set_default_tensor_type,
    device_str,
    expected_type,
    expected_index,
    expected_tensor_type,
):
    """Check that ``set_torch_config`` applies the requested device.

    For each parametrized ``device_str``, this test does the following:

    - builds a ``TorchSettings`` and passes it to ``set_torch_config``;
    - asserts that ``default_device()`` reports the expected device type and
      index;
    - asserts that ``torch.set_default_tensor_type`` was called exactly once
      with the expected tensor class.

    ``torch.set_default_tensor_type`` is replaced by a mock for the duration
    of the test. The mock is the first positional argument, injected by
    ``mock.patch.object``.

    Afterwards, the configuration is always reset with
    ``TorchSettings(device=None)``, even if an assertion fails.
    """
    try:
        set_torch_config(TorchSettings(device=device_str))

        device = default_device()
        assert device.type == expected_type
        # expected_index is None for devices given without an explicit ":N" suffix.
        if expected_index is None:
            assert device.index is None
        else:
            assert device.index == expected_index

        mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type)
    finally:
        # Restore the defaults. A bare try/finally re-raises any failure
        # after this runs, so no except clause is needed.
        set_torch_config(TorchSettings(device=None))
|