问题描述
我是使用scipy的curve_fit()的初学者。我不明白我的以下代码有什么问题:
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
def func(x,a,b,c):
return a * np.exp(-b * x) + c
xdata = [2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
ydata = [75,66,63,61,60,58,55,56,54,59,57,56]
popt,pcov = curve_fit(func,xdata,ydata)
它返回RuntimeWarning:exp中遇到溢出
有什么主意吗?预先感谢!
解决方法
如this post中所述,np.exp
很快就会溢出。您可以在b
上添加bounds来避免溢出。请注意,您只会收到警告,并且curve_fit
的结果不会受到影响。
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
def func(x,a,b,c):
return a * np.exp(-b * x) + c
xdata = [2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22]
ydata = [75,66,63,61,60,58,55,56,54,59,57,56]
popt,pcov = curve_fit(func,xdata,ydata,bounds=([-np.inf,0.0001,-np.inf],[np.inf,np.inf]))
xs = np.linspace(2,22,100)
plt.plot(xs,func(xs,*popt))
plt.scatter(xdata,ydata)
plt.show()
PS:还请注意,拟合函数使用x
的数据类型,这有时会引起奇怪的问题。在此示例中,没有问题,但通常可以添加xdata = np.array(xdata,dtype=float)
或xdata = np.array(xdata,dtype=np.longdouble)
。