Numba CUDA共享内存矩阵乘法

问题描述

我正在运行用于矩阵乘法的共享内存numba代码,但是我认为解决该问题的算法不正确,因为我得到的结果不正确

我看到了该代码的另一个线程,但答案仍未公开,代码无法正常工作

代码如下:

# This part is for initializing everything
M = 128
N = 32


a = np.arange(M*N).reshape(M,N).astype(np.int32)
b = np.arange(M*N).reshape(N,M).astype(np.int32)
c = np.zeros((M,M)).astype(np.int32)

d_a = cuda.to_device(a)
d_b = cuda.to_device(b)
d_c = cuda.to_device(c)

block_size = (N,N)
grid_size = (int(M/N),int(M/N))

这是我的内核定义:

import numpy as np
from numba import cuda,types
@cuda.jit
def fast_matmul(A,B,C):
    # Define an array in the shared memory
    # The size and type of the arrays must be known at compile time
    TPB = N
    
    sA = cuda.shared.array(shape=(TPB,TPB),dtype=float32)
    sB = cuda.shared.array(shape=(TPB,dtype=float32)

    x,y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bpg = cuda.gridDim.x    # blocks per grid

    if x >= C.shape[0] and y >= C.shape[1]:
        # Quit if (x,y) is outside of valid C boundary
        return

    # Each thread computes one element in the result matrix.
    # The dot product is chunked into dot products of TPB-long vectors.
    tmp = 0.
    for i in range(bpg):
        # Preload data into shared memory
        sA[tx,ty] = A[x,ty + i * TPB]
        sB[tx,ty] = B[tx + i * TPB,y]

        # Wait until all threads finish preloading
        cuda.syncthreads()

        # Computes partial product on the shared memory
        for j in range(TPB):
            tmp += sA[tx,j] * sB[j,ty]

        # Wait until all threads finish computing
        cuda.syncthreads()

    C[x,y] = tmp

我从这里跟随:https://numba.pydata.org/numba-doc/dev/cuda/examples.html

但是运行代码会给我带来奇怪的结果,例如

x: array([[2147483647,2147483647,...,2147483647],[2147483647,...

应该是这样的时候:

 y: array([[  1333248,1333744,1334240,1395248,1395744,1396240],[  3364864,3366384,3367904,3554864,3556384,...

有人可以指出我要去哪里了吗?

解决方法

  1. 您要将int32数组传递给期望float32数据的numba内核。不要那样做。
  2. 您实际上还没有显示内核启动。请提供完整的代码。
  3. 不清楚xy是什么。您的代码仅产生一个结果,并且在d_c中。
  4. 我还怀疑您的输入数据将溢出int32类型。您可能应该始终转换为float32。使用arange使得难以快速/直观地验证数字正确性。当您尝试使事情正常运行时,请改用ones

以下是您所展示内容的一个版本,其中考虑了上述想法:

$ cat t35.py
import numpy as np
from numba import cuda,types,float32
@cuda.jit
def fast_matmul(A,B,C):
    # Define an array in the shared memory
    # The size and type of the arrays must be known at compile time
    TPB = N

    sA = cuda.shared.array(shape=(TPB,TPB),dtype=float32)
    sB = cuda.shared.array(shape=(TPB,dtype=float32)

    x,y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bpg = cuda.gridDim.x    # blocks per grid

    if x >= C.shape[0] and y >= C.shape[1]:
        # Quit if (x,y) is outside of valid C boundary
        return

    # Each thread computes one element in the result matrix.
    # The dot product is chunked into dot products of TPB-long vectors.
    tmp = 0.
    for i in range(bpg):
        # Preload data into shared memory
        sA[tx,ty] = A[x,ty + i * TPB]
        sB[tx,ty] = B[tx + i * TPB,y]

        # Wait until all threads finish preloading
        cuda.syncthreads()

        # Computes partial product on the shared memory
        for j in range(TPB):
            tmp += sA[tx,j] * sB[j,ty]

        # Wait until all threads finish computing
        cuda.syncthreads()

    C[x,y] = tmp

# This part is for initializing everything
M = 128
N = 32


#a = np.arange(M*N).reshape(M,N).astype(np.float32)
#b = np.arange(M*N).reshape(N,M).astype(np.float32)
a = np.ones(M*N).reshape(M,N).astype(np.float32)
b = np.ones(M*N).reshape(N,M).astype(np.float32)
c = np.zeros((M,M)).astype(np.float32)

d_a = cuda.to_device(a)
d_b = cuda.to_device(b)
d_c = cuda.to_device(c)

block_size = (N,N)
grid_size = (int(M/N),int(M/N))

fast_matmul[grid_size,block_size](d_a,d_b,d_c)
c = d_c.copy_to_host()
print(c)
$ python t35.py
[[32. 32. 32. ... 32. 32. 32.]
 [32. 32. 32. ... 32. 32. 32.]
 [32. 32. 32. ... 32. 32. 32.]
 ...
 [32. 32. 32. ... 32. 32. 32.]
 [32. 32. 32. ... 32. 32. 32.]
 [32. 32. 32. ... 32. 32. 32.]]
$

我相信32是正确的答案。

还请注意,此示例的发布示例可能有一些错误,请参见here

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...