如何在python中拟合三个高斯峰?

问题描述

我正在尝试使用python拟合三个峰。我能够拟合第一个峰,但是在将拟合函数收敛到下两个峰时遇到问题。有人可以帮我吗?

我想最初的猜测有问题!

这是代码和数字:

from __future__ import division
import numpy as np
import scipy.signal
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['font.family'] = 'sans-serif'

""" Fitting Function"""
def _2gauss(x,amp1,cen1,sigma1,amp2,cen2,sigma2):
    return amp1*(1/(sigma1*(np.sqrt(2*np.pi))))*(np.exp((-1.0/2.0)*(((x-cen1)/sigma1)**2))) + \
    amp2*(1/(sigma2*(np.sqrt(2*np.pi))))*(np.exp((-1.0/2.0)*(((x-cen2)/sigma2)**2)))+ \
    amp3*(1/(sigma3*(np.sqrt(2*np.pi))))*(np.exp((-1.0/2.0)*(((x-cen3)/sigma3)**2))) 
data_12 = np.loadtxt("ExcitationA.txt",skiprows=30,dtype=np.float64)
xData,yData = np.hsplit(data_12,2)
x = xData[:,0]
y = yData[:,0]
n = len(x)
amp1 = 400
sigma1 = 10
cen1 = 400

amp2 = 400
sigma2 = 5
cen2 = 400

amp3 = 340
sigma3 = 6
cen3 = 340
popt,pcov = curve_fit(_2gauss,x,y,p0= [amp1,sigma2])
fig,ax = plt.subplots(figsize=(8,6))
ax.plot(x,'b',markersize=1,label="12°C")
ax.plot(x,_2gauss(x,*popt),markersize='1',label="Fit function",linewidth=4,color='purple')
plt.show()

enter image description here

解决方法

由于有9个参数,为了获得良好的拟合度,这些参数的初始值应接近。一个想法是尝试绘画

p0 = [amp1,cen1,sigma1,amp2,cen2,sigma2,amp3,cen3,sigma3]
ax.plot(x,_2gauss(x,*p0))

直到参数或多或少相等。在此示例中,重要的是,中心cen1cen2cen3接近观测到的局部最大值(340、355、375)。

一旦有了合理的初始值,就可以开始拟合。还要注意,在最初发布的示例代码中,amp3,sigma3作为函数_2gauss的参数缺失。

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

def gauss_1(x,amp1,sigma1):
    return amp1 * (1 / (sigma1 * (np.sqrt(2 * np.pi)))) * (np.exp((-1.0 / 2.0) * (((x - cen1) / sigma1) ** 2)))

def gauss_3(x,sigma3):
    """ Fitting Function"""
    return amp1 * (1 / (sigma1 * (np.sqrt(2 * np.pi)))) * (np.exp((-1.0 / 2.0) * (((x - cen1) / sigma1) ** 2))) + \
           amp2 * (1 / (sigma2 * (np.sqrt(2 * np.pi)))) * (np.exp((-1.0 / 2.0) * (((x - cen2) / sigma2) ** 2))) + \
           amp3 * (1 / (sigma3 * (np.sqrt(2 * np.pi)))) * (np.exp((-1.0 / 2.0) * (((x - cen3) / sigma3) ** 2)))

x = np.array([300.24,301.4,302.56,303.72,304.88,306.04,307.2,308.36,309.51,310.67,311.83,312.99,314.04,314.93,315.77,316.56,317.3,318.03,318.77,319.5,320.23,321.02,321.86,325.76,326.6,327.54,328.49,329.17,329.69,330.27,330.84,331.16,335.85,336.37,337.05,337.79,339.58,341.43,342.42,343.87,345.01,346.07,346.91,347.53,348.06,348.53,348.89,351.33,351.8,352.11,352.42,352.75,353.15,353.6,354.04,354.36,354.87,355.77,356.72,357.36,357.83,358.25,358.69,358.96,359.29,359.61,359.93,360.25,360.58,360.86,361.16,361.39,361.61,361.96,362.3,362.62,363.0,363.43,363.94,364.55,365.18,366.14,367.3,368.19,368.82,369.45,370.03,371.07,371.54,371.96,372.31,372.69,373.11,373.52,373.99,374.67,375.68,376.58,377.11,377.54,377.81,378.09,378.4,378.71,378.94,379.08,379.3,379.52,379.73,379.95,380.17,380.34,380.61,380.82,380.99,381.22,381.44,381.66,381.88,382.1,382.32,382.53,382.75,382.97,383.24,383.74,384.0,384.28,384.49,384.71,384.92,385.14,385.36,385.58,385.9,386.26,386.6,386.92,387.29,387.71,388.31,388.84,389.53,390.38,391.39,392.56,393.72,394.89,396.05,397.22,397.69,398.38,398.86,399.54,400.02,400.71,401.18,401.87,402.34,403.03,403.19,404.19,405.36,406.52,407.68,408.84,410.01,411.17,412.33,413.49,414.65,415.81,416.98,417.61])
y = np.array([3.6790e-01,4.1930e-01,4.6530e-01,5.1130e-01,5.6300e-01,6.1750e-01,6.6780e-01,7.2950e-01,7.8830e-01,8.4960e-01,9.0950e-01,9.6660e-01,1.0463e+00,1.1324e+00,1.2241e+00,1.3026e+00,1.3889e+00,1.4780e+00,1.5598e+00,1.6432e+00,1.7318e+00,1.8256e+00,1.9050e+00,2.1595e+00,2.2477e+00,2.3343e+00,2.4183e+00,2.5115e+00,2.5970e+00,2.6825e+00,2.7657e+00,2.8198e+00,3.8983e+00,3.9956e+00,4.0846e+00,4.1526e+00,4.2787e+00,4.2256e+00,4.2412e+00,4.2731e+00,4.3265e+00,4.4073e+00,4.4905e+00,4.5831e+00,4.6717e+00,4.7660e+00,4.8395e+00,5.6288e+00,5.7239e+00,5.8141e+00,5.9076e+00,6.0026e+00,6.1034e+00,6.2157e+00,6.3235e+00,6.4114e+00,6.5063e+00,6.5709e+00,6.5175e+00,6.4349e+00,6.3479e+00,6.2638e+00,6.2102e+00,6.0616e+00,5.9664e+00,5.8697e+00,5.7625e+00,5.6546e+00,5.5494e+00,5.4404e+00,5.3384e+00,5.2396e+00,5.1462e+00,5.0412e+00,4.9467e+00,4.8592e+00,4.7655e+00,4.6709e+00,4.5807e+00,4.4803e+00,4.3947e+00,4.3347e+00,4.3286e+00,4.3918e+00,4.4800e+00,4.5637e+00,4.6489e+00,4.8435e+00,4.9454e+00,5.0396e+00,5.1258e+00,5.2200e+00,5.3082e+00,5.3945e+00,5.4874e+00,5.5974e+00,5.6396e+00,5.5880e+00,5.4984e+00,5.4082e+00,5.3213e+00,5.2270e+00,5.1271e+00,5.0247e+00,4.9258e+00,4.8324e+00,4.7317e+00,4.6336e+00,4.5323e+00,4.4258e+00,4.3166e+00,4.2152e+00,4.1011e+00,3.9754e+00,3.8646e+00,3.7401e+00,3.6061e+00,3.4715e+00,3.3381e+00,3.2120e+00,3.0865e+00,2.9610e+00,2.8361e+00,2.7126e+00,2.6289e+00,2.2796e+00,2.1818e+00,2.0747e+00,1.9805e+00,1.8864e+00,1.7942e+00,1.7080e+00,1.6236e+00,1.5279e+00,1.4145e+00,1.2931e+00,1.1805e+00,1.0785e+00,9.8490e-01,8.9590e-01,7.9850e-01,7.0670e-01,6.2110e-01,5.2990e-01,4.4250e-01,3.7360e-01,3.1090e-01,2.5880e-01,2.0680e-01,1.6760e-01,1.4570e-01,1.2690e-01,1.1060e-01,9.5900e-02,9.0600e-02,8.0600e-02,7.0600e-02,5.8100e-02,4.4200e-02,4.1400e-02,3.4900e-02,2.4200e-02,1.9600e-02,1.5300e-02,1.5000e-02,1.1800e-02,1.3200e-02,7.8000e-03,5.0000e-03,1.0000e-02,4.6000e-03,0.0])
amp1 = 100
sigma1 = 9
cen1 = 375
amp2 = 100
sigma2 = 7
cen2 = 355
amp3 = 100
sigma3 = 10
cen3 = 340
p0 = [amp1,sigma3]
y0 = gauss_3(x,*p0)

popt,pcov = curve_fit(gauss_3,x,y,p0=p0)

fig,ax = plt.subplots(figsize=(8,6))
ax.plot(x,'b',label="given curve")
ax.plot(x,y0,'g',ls=':',label="initial fit params")
ax.plot(x,gauss_3(x,*popt),label="Fit function",linewidth=4,color='purple')
for i,(a,c,s )in enumerate( popt.reshape(-1,3)):
    ax.plot(x,gauss_1(x,a,s),ls='-',label=f"gauss {i+1}",linewidth=1,color='crimson')
ax.legend()
ax.autoscale(axis='x',tight=True)
plt.show()

example plot