用Numba和CUDA求和的数组

问题描述

我刚刚开始学习如何使用Numba和CUDA进行编程,因此此代码可能是非常错误的,但是我不明白为什么它不起作用。我正在尝试对N个不同的数组求和,其内容取决于另一个数组。显示代码可能比以下解释更好:

import numba as nb
from numba import cuda
import numpy as np
from math import exp,ceil

t0s = np.array([2.5,6.7,8.1,9.6,10.5])
threadsperblock = 32
blockspergrid = ceil(t0s.shape[0] / threadsperblock)

time = np.linspace(0,10,2000)
waveform = np.zeros_like(time)
total_waveform = np.zeros_like(waveform)

@cuda.jit(device=True)
def current(waveform,time,t0):
    for i in range(waveform.shape[0]):
        if time[i] > t0:
            waveform[i] = 0
        else:
            waveform[i] = exp(time[i]-t0)

@cuda.jit
def total(time,waveform,total_waveform,t0s):
    i = cuda.grid(1)
    if i < t0s.shape[0]:
        current(waveform,t0s[i]) 
        for j in range(total_waveform.shape[0]):
            total_waveform[j] += waveform[j]

total[blockspergrid,threadsperblock](time,t0s)

不幸的是,total_waveform仅包含第一个波形(就像在t0s的第一个元素之后停止一样),我真的不明白为什么。救命! :)

解决方法

基于已发布的代码和此注释:

我的正确结果将是一个包含5条上升的指数曲线的数组,每条曲线都以t0s[i]结尾

假设您的意思是,您似乎可以极大地简化代码并获得所需的结果

我的正确结果是一个数组,其中包含 5条上升的指数曲线的总和,每条曲线的终点为t0s[i]

当t0较大时,每条曲线在小t处接近零,而对于所有t0> 0,每条曲线在[0,t0)上始终不为零。如果我没有误解您的意图和代码,您可以: / p>

  1. current更改为标量函数
  2. 一起消除waveform,这是不需要存储的中间结果
  3. 更改并行化策略,以便每个线程仅计算输出中的一个时间点(即,从原始代码中反转循环的顺序)。如果这样做,则不会出现内存争用或同步问题。

如果您做这三件事,您将得到如下信息:

$ cat wavegoodbye.py 
import numba as nb
from numba import cuda
import numpy as np
from math import exp,ceil

t0s = np.array([2.5,6.7,8.1,9.6,10.5])
time = np.linspace(0,10,2000)

total_waveform = np.zeros_like(time)

threadsperblock = 32
blockspergrid = ceil(total_waveform.shape[0] / threadsperblock)

@cuda.jit(device=True)
def current(time,t0):
    if time > t0:
       waveform = 0
    else:
       waveform = exp(time-t0)

    return waveform

@cuda.jit
def total(time,total_waveform,t0s):
    i = cuda.grid(1)
    if i < total_waveform.shape[0]:
        for j in range(t0s.shape[0]):
            total_waveform[i] += current(time[i],t0s[j]) 

total[blockspergrid,threadsperblock](time,t0s)

这样做:

$ ipython
Python 3.7.4 (default,Aug 13 2019,20:35:49) 
Type 'copyright','credits' or 'license' for more information
IPython 7.11.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: %run wavegoodbye.py                                                                                                                          

In [2]: import pylab as pl                                                                                                                           

In [3]: pl.plot(time,total_waveform)

enter image description here

我想这就是你的想法。