Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

41 行
1.3 KiB

import pytest
from unittest import mock
import torch # noqa I201
from mlagents.torch_utils import set_torch_config, default_device
from mlagents.trainers.settings import TorchSettings
@pytest.mark.parametrize(
    "device_str, expected_type, expected_index, expected_tensor_type",
    [
        ("cpu", "cpu", None, torch.FloatTensor),
        ("cuda", "cuda", None, torch.cuda.FloatTensor),
        ("cuda:42", "cuda", 42, torch.cuda.FloatTensor),
        ("opengl", "opengl", None, torch.FloatTensor),
    ],
)
@mock.patch.object(torch, "set_default_tensor_type")
def test_set_torch_device(
    mock_set_default_tensor_type,
    device_str,
    expected_type,
    expected_index,
    expected_tensor_type,
):
    """Check that ``set_torch_config`` applies the requested device string.

    ``torch.set_default_tensor_type`` is patched for the duration of the test,
    so the call is only recorded on the mock and can be asserted on.

    :param mock_set_default_tensor_type: Mock injected by ``mock.patch.object``
        in place of ``torch.set_default_tensor_type``.
    :param device_str: Device string passed to ``TorchSettings(device=...)``.
    :param expected_type: Expected ``default_device().type``.
    :param expected_index: Expected ``default_device().index`` (``None`` when
        the device string carries no index).
    :param expected_tensor_type: Tensor type expected to be handed to
        ``torch.set_default_tensor_type``.
    """
    try:
        set_torch_config(TorchSettings(device=device_str))
        assert default_device().type == expected_type
        # A single equality check covers both the ``None`` and the integer
        # expectation, so no branching on ``expected_index`` is needed.
        assert default_device().index == expected_index
        # Must run before the ``finally`` block: restoring the defaults calls
        # ``set_torch_config`` again while the mock is still installed.
        mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type)
    finally:
        # Restore the defaults even when an assertion above fails, so the
        # device configured here does not carry over past this test.
        set_torch_config(TorchSettings(device=None))