如何加速numpy数组屏蔽?

问题描述

我正在使用line profiler来分析一段Python代码的性能。

在代码中,我有一个形状为(106906,)和dtype = {tt}的numpy数组int64。在事件探查器的帮助下,我发现mask[tt]=True下面的第二行很慢。反正有加速吗?如果那很重要,我将使用Python 3。

   mask = np.zeros(100000,dtype='bool')
   mask[tt] = True

解决方法

您可以按照@orlevii的建议使用Numba:

from numba import njit
@njit
def f(mask,tt):
    mask[tt] = True
#Test:
mask = np.zeros(1000000,dtype='bool')
tt = np.random.randint(0,1000000,106906)
f(mask,tt)

简单的%%timeit检查表明您应该期望执行速度快大约3倍。

利用GPU可以进一步提高速度。使用PyTorch的示例:

import torch
mask = torch.zeros(1000000).type(torch.cuda.FloatTensor)
tt = torch.randint(0,torch.Size([106906])).type(torch.cuda.LongTensor)
mask[tt] =  True

请注意,这里我们使用torch.Tensor对象,它与PyTorch中的numpy.ndarray等效。仅当您拥有带有CUDA的(NVIDIA)GPU时,代码才会运行。希望在Tesla V100-SXM2上获得原始代码的30倍加速。

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...