您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
39 行
1.1 KiB
39 行
1.1 KiB
import os
|
|
|
|
from mlagents.torch_utils import cpu_utils
|
|
|
|
# Detect availability of torch package here.
# NOTE: this try/except is temporary until torch is required for ML-Agents.
try:
    # This should be the only place that we import torch directly.
    # Everywhere else is caught by the banned-modules setting for flake8
    import torch  # noqa I201

    # Limit torch's CPU thread pool to the count chosen by cpu_utils.
    torch.set_num_threads(cpu_utils.get_num_threads_to_use())
    # NOTE(review): KMP_BLOCKTIME is an OpenMP runtime setting; whether assigning
    # it after torch has been imported still takes effect depends on when the
    # OpenMP runtime reads the environment — confirm.
    os.environ["KMP_BLOCKTIME"] = "0"

    # Known PyLint compatibility with PyTorch https://github.com/pytorch/pytorch/issues/701
    # pylint: disable=E1101
    # NOTE(review): newer PyTorch releases deprecate set_default_tensor_type in
    # favour of set_default_dtype/set_default_device; kept here so older torch
    # versions keep working.
    if torch.cuda.is_available():
        # New float tensors default to CUDA tensors when a GPU is present.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        device = torch.device("cuda")
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
        device = torch.device("cpu")
    nn = torch.nn
    # Re-enable the check so the suppression above only covers the torch setup.
    # pylint: enable=E1101
except ImportError:
    # torch could not be imported: expose None placeholders so the rest of the
    # module (e.g. is_available / default_device) still works.
    torch = None
    nn = None
    device = None
|
|
|
|
|
|
def default_device():
    """
    Return the torch device chosen when this module was loaded.

    This is ``torch.device("cuda")`` if CUDA was available at import time,
    ``torch.device("cpu")`` otherwise, or ``None`` if the torch setup raised
    ImportError.
    """
    selected = device
    return selected
|
|
|
|
|
|
def is_available():
    """
    Report whether torch can be used in this Python environment.

    :return: False if the module-level torch setup raised ImportError (which
        leaves ``torch`` set to None), True otherwise.
    """
    torch_loaded = torch is not None
    return torch_loaded
|