Python curve_fit“对象对于所需数组太深”

问题描述

我想在 python 中绘制这个函数( y = a * ln(x) + b )。 这是我的代码

def func(x,a,b):
return a * np.log(x) + b

popt,_ = curve_fit(func,x,y)
a,b = popt
x_line = arrange(min(x),max(x),1)
y_line = func(x_line,b)
plt.plot(x_line,y_line)
plt.show()

我的“x”包含这个

array([[1790],[1800],[1810],[1820],[1830],[1840],[1850],[1860],[1870],[1880],[1900],[1910],[1920],[1930],[1940],[1950],[1960],[1970],[1980],[1990],[2000],[2010]],dtype=int64)

还有我的“y”这个

array([[  3.929214],[  5.308483],[  7.239881],[  9.638453],[ 12.86602 ],[ 17.069453],[ 23.191876],[ 31.443321],[ 39.818449],[ 50.189209],[ 76.212168],[ 92.228496],[106.021537],[123.202624],[132.164569],[151.325798],[179.323175],[203.302031],[226.542199],[248.718302],[281.424603],[308.745538]])

但是当我运行代码时,我总是得到这个错误

object too deep for desired array

我希望有人能帮助我,因为我花了很多时间。

解决方法

尝试重塑你的数组:

popt,_ = curve_fit(func,x.reshape(-1),y.reshape(-1))
,

您的 xy 变量是二维 (22 x 1) 数组,因为当 scipy.optimize.curve_fit 需要一维数组时,内部有一组方括号。

您可以删除内括号或切片 xy

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

def func(x,a,b):
  return a * np.log(x) + b

popt,x[:,0],y[:,0])
a,b = popt
x_line = np.arange(min(x),max(x),1)
y_line = func(x_line,b)
plt.plot(x_line,y_line)
plt.show()