问题描述
我正在尝试加快Python中特定的(数字)积分的速度。我在Mathematica中进行了评估,需要14秒。在python中,它需要15.6分钟!
我要评估的积分形式为:
python代码如下:
from mpmath import hermite
def light_nm( dipol,n,m,t):
mat_elem = light_amp(n)*light_amp_conj(m)*coef_ground( dipol,t)*np.conj(coef_ground( dipol,t)) + \
light_amp(n+1)*light_amp_conj(m+1)*coef_excit( dipol,n+1,t)*np.conj(coef_excit( dipol,m+1,t))
return mat_elem
def light_nm_dmu( dipol,t):
mat_elem = light_amp(n)*light_amp_conj(m)*(coef_ground_dmu( dipol,t)*conj(coef_ground( dipol,t)) + coef_ground( dipol,t)*conj(coef_ground_dmu( dipol,t)) )+ \
light_amp(n+1)*light_amp_conj(m+1)*(coef_excit_dmu( dipol,t)) + coef_excit( dipol,t)*conj(coef_excit_dmu( dipol,t)))
return mat_elem
def prob(dipol,t,x,thlo,cutoff,m):
temp = complex( light_nm(dipol,t)* cmath.exp(1j*thlo*(n-m)-x**2)*\
hermite(n,x)*hermite(m,x)/math.sqrt(2**(n+m)*math.factorial(m)*math.factorial(n)*math.pi))
return np.real(temp)
def derprob(dipol,m):
temp = complex( light_nm_dmu(dipol,t)* cmath.exp(1j*thlo*(n-m)-x**2)*\
hermite(n,x)/math.sqrt(2**(n+m)*math.factorial(m)*math.factorial(n)*math.pi))
if np.imag(temp)>10**(-6):
print(t)
return np.real(temp)
def integrand(dipol,x):
return 1/np.sum(np.array([ prob(dipol,m) for n,m in product(range(cutoff),range(cutoff))]))*\
np.sum(np.array([ derprob(dipol,range(cutoff))]))**2
def cfi(dipol,a):
global alpha
alpha = a
temp_func_real = lambda x: np.real(integrand(dipol,x))
temp_real = integ.quad(temp_func_real,-8,8)
return temp_real[0]
hermite函数是从mpmath库中调用的。 有什么方法可以使此代码更快地工作?
谢谢!
已更新: 我添加了整个代码。 (对不起,我很抱歉) 函数“ light_nm_dmu”类似于“ light_nm”。 我尝试了答案,但是在light_amp函数中收到错误“ TypeError:只有大小为1的数组可以转换为Python标量”,因此我将prob和derprob向量化。
相同评估的新时间为886.7085871696472 = 14.8分钟(cfi(0.1,1,40,1))
解决方法
建议使用:
-
使用缓存来加速大量数字(即Is math.factorial memorized?)的阶乘计算(Domenico De Felice修改答案)
更新代码
# use cached factorial function
def prob(dipol,t,x,thlo,cutoff,n,m):
temp = complex( light_nm(dipol,m,t)* cmath.exp(1j*thlo*(n-m)-x**2)*\
hermite(n,x)*hermite(m,x)/math.sqrt(2**(n+m)*factorial(m)*factorial(n)*math.pi))
return np.real(temp)
# Vectorize computation
def integrand(dipol,x):
xaxis = np.arange(0,cutoff)
yaxis = np.arange(0,cutoff)
return 1/np.sum(prob(dipol,xaxis[:,None],yaxis[None,:]))*\
np.sum(derprob(dipol,:]))**2
# unchanged
def cfi(dipol,a):
global alpha
alpha = a
temp_func_real = lambda x: np.real(integrand(dipol,x))
temp_real = integ.quad(temp_func_real,-8,8)
return temp_real[0]
# Cached factorial
def factorial(num,fact_memory = {0: 1,1: 1,'max': 1}):
' Cached factorial since we're computing on lots of numbers '
# Factorial is defined only for non-negative numbers
assert num >= 0
if num <= fact_memory['max']:
return fact_memory[num]
for x in range(fact_memory['max']+1,num+1):
fact_memory[x] = fact_memory[x-1] * x
fact_memory['max'] = num
return fact_memory[num]