问题描述
我正在尝试将scipy.optimize.curve_fit
与自定义拟合函数结合使用(大致遵循this教程):
# Fit function
def fit_function(x,y,x0,y0,A,FWHM):
return A*np.exp(1)*4*np.log(2)*((x+x0)**2 + (y+y0)**2)/FWHM**2*np.exp(-4*np.log(2)*((x+x0)**2 + (y+y0)**2)/FWHM**2)
# Open image file
img = Image.open('/home/user/image.tif')
# xdata
X,Y = img.size
xRange = np.arange(1,X+1)
yRange = np.arange(1,Y+1)
xGrid,yGrid = np.meshgrid(xRange,yRange)
xyGrid = np.vstack((xGrid.ravel(),yGrid.ravel()))
# ydata
imgArray = np.array(img)
imgArrayFlat = imgArray.ravel()
# Fitting
params_opt,params_cov = curve_fit(fit_function,xyGrid,imgArrayFlat)
由于某些原因,Jupyter Notebook不断抛出此错误,而我在代码中找不到该问题:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-61-eaa3ebdb6469> in <module>()
17 imgArrayFlat = imgArray.ravel() # Flatten 2D pixel data into 1D array for scipy.optimize.curve_fit
18
---> 19 params_opt,params_cov = curve_fit(doughnut,imgArrayFlat)
/usr/lib/python3/dist-packages/scipy/optimize/minpack.py in curve_fit(f,xdata,ydata,p0,sigma,absolute_sigma,check_finite,bounds,method,jac,**kwargs)
749 # Remove full_output from kwargs,otherwise we're passing it in twice.
750 return_full = kwargs.pop('full_output',False)
--> 751 res = leastsq(func,Dfun=jac,full_output=1,**kwargs)
752 popt,pcov,infodict,errmsg,ier = res
753 cost = np.sum(infodict['fvec'] ** 2)
/usr/lib/python3/dist-packages/scipy/optimize/minpack.py in leastsq(func,args,Dfun,full_output,col_deriv,ftol,xtol,gtol,maxfev,epsfcn,factor,diag)
384 m = shape[0]
385 if n > m:
--> 386 raise TypeError('Improper input: N=%s must not exceed M=%s' % (n,m))
387 if epsfcn is None:
388 epsfcn = finfo(dtype).eps
TypeError: Improper input: N=5 must not exceed M=2
我不明白N
和M
指的是什么,但是我读过某个地方,当数据点少于参数(系统欠定)时,会抛出此错误-这是不是这里的情况,因为图像文件每个都有大约15 x 15 = 225个数据点。可能是什么原因引起的麻烦?
解决方法
可能您需要将功能更改为
def fit_function(X,x0,y0,A,FWHM):
x,y = X
return A*np.exp(1)*4*np.log(2)*((x+x0)**2 + (y+y0)**2)/FWHM**2*np.exp(-4*np.log(2)*((x+x0)**2 + (y+y0)**2)/FWHM**2)
因为只有第一个变量被视为独立变量。
当前,您在x
变量内发送了一个数组,该变量是从两个1D数组np.vstack
开始的,因此M=2
:您有两个数据点。在该函数中,所有其他参数都被视为要优化的参数(包括y
!),因此是N=5
。