您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
43 行
1.3 KiB
43 行
1.3 KiB
import os
|
|
|
|
def _thread_count_from_env(var_name):
    """
    Read an optional integer thread count from the environment.

    :param var_name: Name of the environment variable to read.
    :return: The parsed integer, or None when the variable is not set.
    :raises ValueError: If the variable is set but cannot be parsed as an
        integer; the message names the offending variable and its value.
    """
    raw_value = os.environ.get(var_name)
    if raw_value is None:
        return None
    try:
        return int(raw_value)
    except ValueError as err:
        # A bare "invalid literal for int()" raised at import time gives no hint
        # about which setting is wrong, so name the variable explicitly.
        raise ValueError(
            "Environment variable {} must be an integer, got {!r}".format(
                var_name, raw_value
            )
        ) from err


# Detect availability of torch package here.
# NOTE: this try/except is temporary until torch is required for ML-Agents.
try:
    # This should be the only place that we import torch directly.
    # Everywhere else is caught by the banned-modules setting for flake8
    import torch  # noqa I201

    # Thread pool sizes are only changed when explicitly requested through the
    # environment; otherwise torch's own defaults are left in place.
    _num_threads = _thread_count_from_env("TORCH_NUM_THREADS")
    if _num_threads is not None:
        torch.set_num_threads(_num_threads)

    _num_interop = _thread_count_from_env("TORCH_NUM_INTEROP")
    if _num_interop is not None:
        torch.set_num_interop_threads(_num_interop)

    # Disabled tuning settings, kept for reference:
    # torch.set_num_interop_threads(4)
    # os.environ["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
    # os.environ["KMP_BLOCKTIME"] = "1"

    # Known PyLint compatibility with PyTorch https://github.com/pytorch/pytorch/issues/701
    # pylint: disable=E1101
    # NOTE(review): torch.set_default_tensor_type is documented as deprecated in
    # recent PyTorch releases -- confirm against the torch version in use.
    if torch.cuda.is_available():
        # Newly created tensors default to CUDA float tensors.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        device = torch.device("cuda")
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
        device = torch.device("cpu")
    nn = torch.nn
    # NOTE(review): the directive below repeats the "disable" above; it may have
    # been intended as "enable" -- confirm before changing.
    # pylint: disable=E1101
except ImportError:
    # torch is not installed: expose None placeholders so importers can check
    # availability instead of failing at import time.
    torch = None
    nn = None
    device = None
|
|
|
|
|
|
def default_device():
    """
    Returns the module-level `device` value chosen when this module was imported
    (None when torch could not be imported)
    """
    selected_device = device
    return selected_device
|
|
|
|
|
|
def is_available():
    """
    Reports whether the torch package could be imported in this Python environment

    :return: False when the module-level `torch` name is None, True otherwise
    """
    torch_missing = torch is None
    return not torch_missing
|