带约束的多项式插值

问题描述

我正在尝试使用 scipy 函数创建多项式插值。我创建了一些函数来遍历我的所有数据点(实际上是 5 个数据点),但我想对最小值添加一个约束。我的代码如下:

import numpy as np
from scipy.interpolate import *
import matplotlib.pyplot as plt

x = np.array([10,25,50,75,90])
y = np.array([5.7239,5.53,5.37,5.43,5.789])

poly = lagrange(x,y)
xpol = np.linspace(x[0],x[4],1000)
ypol = poly(xpol)

plt.scatter(x,y,marker='s',c='r')
plt.plot(xpol,ypol)

我得到的情节如下:

enter image description here

问题是点 (50,5.37) 应该是我的抛物线的最小值(参见图中的黑线),我不知道如何用 scipy.谢谢...

解决方法

使用 scipy.lagrange 可以获得高阶多项式,而不是抛物线。

你可以试试curve_fit

def parabola(x,a,b,c):
    return (a * x + b) * x + c

from scipy.optimize import curve_fit
params,_ = curve_fit(parabola,x,y)
print(params)

[ 2.52716219e-04 -2.52482511e-02  5.97428457e+00]

enter image description here

这是一条抛物线,但顶点不在您需要的点上。

另一种选择是minimize

def parabola_error(params):
    err = parabola(x,*params) - y
    return err.dot(err)            # Sum of squares

from scipy.optimize import minimize
res = minimize(parabola_error,x0=np.array([1,1,1]),method='bfgs')
print(res.x)

[ 2.52323815e-04 -2.52090387e-02  5.97365419e+00]

实际上与 curve_fit 的结果相同,但我们现在可以使用 SLSQP 方法添加顶点位置作为附加约束,例如:

抛物线的顶点由 x0=-b/(2a),y0=c-b²/(4a) 给出:

vertex_x = x[2]     # As required
vertex_y = y[2]

def vertex_x_error(params):
    a,c = params
    return - b / (2 * a) - vertex_x        # Not squared!

def vertex_y_error(params):
    a,c = params
    return c - b * b / (4 * a) - vertex_y  # Not squared!

constraints = ({'type': 'eq','fun': vertex_x_error},{'type': 'eq','fun': vertex_y_error})

# Using previous result as the seed
res = minimize(parabola_error,x0=res.x,method='SLSQP',constraints=constraints)
print(res.x)

[ 2.32725662e-04 -2.32725662e-02  5.95181415e+00]

enter image description here