cp.array 的慢切片,索引为零维 cp.array基于 cp.argmin 结果

问题描述

我有一些代码,我需要根据 cp.argmin 应用于较小 cp.array 的结果对较大的 cp.array 进行切片。 (见下面的最小代码示例)

问题是,cp.argmin 返回一个零维的 cp.array,而使用 : 运算符进行切片显然需要整数。

import time
import cupy as cp

original = cp.empty((10000,10000))
nrows,ncols = 1000,1000
to_modify = cp.empty((nrows,ncols))
start_time = time.time()
for i in range(10000):
    argmin = cp.argmin(to_modify)
    argmin = int(argmin)
    row_idx,col_idx = (argmin // ncols,argmin % ncols)
    sliced = original[row_idx : row_idx + nrows,col_idx : col_idx + ncols]
    to_modify += sliced

print(time.time() - start_time)

当我分析上面的代码时(我使用 py-spy),最慢的行(大约 90% 的时间)是转换为 argmin 的 int,但是如果我删除它,行 sliced = original[ ... ] 成为最慢的行,因为演员表似乎是隐含地发生的。 有没有办法以一种高效的方式解决我的问题,在切片时避免使用 : 运算符?

解决方法

不幸的是,我没有找到使用 cupy 的可行解决方案。 相反,我开始使用 numba。虽然 numba 需要以编写 cuda 内核的形式进行更多的手动工作,但它也提供了更多的控制。 使用 numba.cuda.device_array()、numba.cuda.to_device() 和 numba.cuda.copy_to_host() 还可以控制何时在 cpu 和 gpu 之间来回复制数组。 最难的部分是实现 argmin,它需要一个 reduce 操作:

@cuda.jit
def find_argmin(data,argmin,tmp_min,tmp_idx):
    x = cuda.grid(1)

    shape_x,shape_y = data.shape[0],data.shape[1]
    num_items = shape_x * shape_y
    num_threads = tmp_min.shape[0]
    num_items_per_thread = num_items // num_threads

    min_val = data[0,0]
    min_ix,min_iy = 0,0

    for idx in range(num_items_per_thread):
        idx = x * num_items_per_thread + idx
        if idx < num_items:
            ix,iy = idx // shape_y,idx % shape_y
            current = data[ix,iy]
            if current < min_val:
                min_val = current
                min_ix = ix
                min_iy = iy

    tmp_idx[x,0] = min_ix
    tmp_idx[x,1] = min_iy
    tmp_min[x] = min_val

    cuda.syncthreads()

    # find minimum in temporary array
    if x == 0:
        min_val = tmp_min[0]
        min_x,min_y = tmp_idx[0,0],tmp_idx[0,1]

        for idx in range(num_threads):
            if tmp_min[idx] < min_val:
                min_val = tmp_min[idx]
                min_x,min_y = tmp_idx[idx,tmp_idx[idx,1]

    argmin[0] = min_x
    argmin[1] = min_y

其余的都是简单的元素操作。