问题描述
我正在尝试复制在另一个问题的所选答案的第 1. 段中描述的技术:How to pass additional parameters to numba cfunc passed as LowLevelCallable to scipy.integrate.quad。
但是,我不知道如何修改实现,使 xx[1] 是一个浮点数组而不是唯一的浮点数。
解决方法
我通过将 Jacques Gaudin 中的 https://stackoverflow.com/a/49732825/3925704 代码修改为:
import numpy as np
import scipy.integrate as si
import numba
from numba import cfunc
from numba.types import intc,CPointer,float64
from scipy import LowLevelCallable
def jit_integrand_function(integrand_function):
jitted_function = numba.jit(integrand_function,nopython=True)
@cfunc(float64(intc,CPointer(float64)))
def wrapped(n,xx):
values = carray(xx,n)
return jitted_function(values)
return LowLevelCallable(wrapped.ctypes)
@jit_integrand_function
def integrand(args):
t = args[0]
a = args[1]
return np.exp(-t/a) / t**2
def do_integrate(func,a):
"""
Integrate the given function from 1.0 to +inf with additional argument a.
"""
return si.quad(func,1,np.inf,args=(a,))