您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
42 行
713 B
42 行
713 B
import time
|
|
import torch
|
|
import numpy as np
|
|
|
|
a = torch.tensor([2., 2.])
|
|
batch_size = 32
|
|
print(a.shape[0])
|
|
|
|
t0 = time.time()
|
|
for i in range(1000):
|
|
b = a.expand(batch_size, -1)
|
|
t1 = time.time()
|
|
print(t0 - t1)
|
|
|
|
t0 = time.time()
|
|
for i in range(1000):
|
|
b = a * torch.ones(batch_size, a.shape[0])
|
|
t1 = time.time()
|
|
|
|
print(t0 - t1)
|
|
|
|
t0 = time.time()
|
|
for i in range(1000):
|
|
b = torch.cat([a for _ in range(batch_size)], axis=0)
|
|
t1 = time.time()
|
|
|
|
print(t0 - t1)
|
|
|
|
t0 = time.time()
|
|
for i in range(1000):
|
|
b = torch.cat([a] * batch_size, axis=0)
|
|
t1 = time.time()
|
|
|
|
print(t0 - t1)
|
|
|
|
|
|
t0 = time.time()
|
|
for i in range(1000):
|
|
b = a * torch.tensor(np.ones((batch_size, a.shape[0])))
|
|
t1 = time.time()
|
|
|
|
print(t0 - t1)
|