问题描述
我正在使用line profiler来分析一段Python代码的性能。
在代码中,我有一个形状为(106906,)和dtype = {tt
}的numpy数组int64
。在事件探查器的帮助下,我发现mask[tt]=True
下面的第二行很慢。反正有加速吗?如果那很重要,我将使用Python 3。
mask = np.zeros(100000,dtype='bool')
mask[tt] = True
解决方法
您可以按照@orlevii的建议使用Numba:
from numba import njit
@njit
def f(mask,tt):
mask[tt] = True
#Test:
mask = np.zeros(1000000,dtype='bool')
tt = np.random.randint(0,1000000,106906)
f(mask,tt)
简单的%%timeit
检查表明您应该期望执行速度快大约3倍。
利用GPU可以进一步提高速度。使用PyTorch的示例:
import torch
mask = torch.zeros(1000000).type(torch.cuda.FloatTensor)
tt = torch.randint(0,torch.Size([106906])).type(torch.cuda.LongTensor)
mask[tt] = True
请注意,这里我们使用torch.Tensor
对象,它与PyTorch中的numpy.ndarray
等效。仅当您拥有带有CUDA的(NVIDIA)GPU时,代码才会运行。希望在Tesla V100-SXM2上获得原始代码的30倍加速。