问题描述
我正在尝试使用 scipy 函数创建多项式插值。我创建了一些函数来遍历我的所有数据点(实际上是 5 个数据点),但我想对最小值添加一个约束。我的代码如下:
import numpy as np
from scipy.interpolate import *
import matplotlib.pyplot as plt
x = np.array([10,25,50,75,90])
y = np.array([5.7239,5.53,5.37,5.43,5.789])
poly = lagrange(x,y)
xpol = np.linspace(x[0],x[4],1000)
ypol = poly(xpol)
plt.scatter(x,y,marker='s',c='r')
plt.plot(xpol,ypol)
问题是点 (50,5.37) 应该是我的抛物线的最小值(参见图中的黑线),我不知道如何用 scipy.谢谢...
解决方法
使用 scipy.lagrange 可以获得高阶多项式,而不是抛物线。
你可以试试curve_fit:
def parabola(x,a,b,c):
return (a * x + b) * x + c
from scipy.optimize import curve_fit
params,_ = curve_fit(parabola,x,y)
print(params)
[ 2.52716219e-04 -2.52482511e-02 5.97428457e+00]
这是一条抛物线,但顶点不在您需要的点上。
另一种选择是minimize:
def parabola_error(params):
err = parabola(x,*params) - y
return err.dot(err) # Sum of squares
from scipy.optimize import minimize
res = minimize(parabola_error,x0=np.array([1,1,1]),method='bfgs')
print(res.x)
[ 2.52323815e-04 -2.52090387e-02 5.97365419e+00]
实际上与 curve_fit 的结果相同,但我们现在可以使用 SLSQP
方法添加顶点位置作为附加约束,例如:
抛物线的顶点由 x0=-b/(2a)
,y0=c-b²/(4a)
给出:
vertex_x = x[2] # As required
vertex_y = y[2]
def vertex_x_error(params):
a,c = params
return - b / (2 * a) - vertex_x # Not squared!
def vertex_y_error(params):
a,c = params
return c - b * b / (4 * a) - vertex_y # Not squared!
constraints = ({'type': 'eq','fun': vertex_x_error},{'type': 'eq','fun': vertex_y_error})
# Using previous result as the seed
res = minimize(parabola_error,x0=res.x,method='SLSQP',constraints=constraints)
print(res.x)
[ 2.32725662e-04 -2.32725662e-02 5.95181415e+00]