如何在调用curve_fit时打包我的numpy变量和数组?

问题描述

这是我重现该问题的独立代码

import numpy as np
from scipy.optimize import curve_fit


def find_vector_of_minor_axis_from_chunk(data):
    n = 20  # number of points
    time = np.linspace(0,2 * np.pi,n)

    guess_center_point = data.mean(1)
    guess_center_point = guess_center_point[np.newaxis,:].transpose()
    guess_a_phase = 0
    guess_b_phase = 0
    guess_a = 1
    guess_b = 1
    guess_a_axis_vector = np.array([[1],[0],[0]])
    guess_b_axis_vector = np.array([[0],[1],[0]])

    p0 = np.array([guess_center_point,guess_a,guess_a_axis_vector,guess_a_phase,guess_b,guess_b_axis_vector,guess_b_phase])

    def ellipse_func(t,center_point,a,a_axis_vector,a_phase,b,b_axis_vector,b_phase):
        return center_point + a * a_axis_vector * np.sin(t * a_phase) + b * b_axis_vector * np.sin(t + b_phase)

    popt,pcov = curve_fit(ellipse_func,time,data,p0=p0)
    center_point,b_phase = popt

    print(str(a_axis_vector,b_axis_vector))
    shorter_vector = a_axis_vector
    if np.abs(a_axis_vector) > np.aps(b_axis_vector):
        shorter_vector = b_axis_vector
    return shorter_vector


def main():
    data = np.array([[-4.62767933,-4.6275775,-4.62735346,-4.62719652,-4.62711625,-4.62717975,-4.62723845,-4.62722407,-4.62713901,-4.62708749,-4.62703238,-4.62689101,-4.62687185,-4.62694013,-4.62701082,-4.62700483,-4.62697488,-4.62686825,-4.62675683,-4.62675204],[-1.58625998,-1.58625039,-1.58619648,-1.58617611,-1.58620606,-1.5861833,-1.5861821,-1.58619169,-1.58615814,-1.58616893,-1.58613179,-1.58615934,-1.58611262,-1.58610782,-1.58614017,-1.58613059,-1.58612699,-1.58607428,-1.58610183],[-0.96714786,-0.96713827,-0.96715984,-0.96715145,-0.96716703,-0.96712869,-0.96716104,-0.96713228,-0.96719698,-0.9671838,-0.96717062,-0.96715744,-0.96707717,-0.96709275,-0.96706519,-0.96715026,-0.96711791,-0.96713588,-0.96714786]])

    print(str(find_vector_of_minor_axis_from_chunk(data)))

if __name__ == '__main__':
    main()

这给了我这个追溯:

Traceback (most recent call last):
  File "C:/Users/X/PycharmProjects/lissajous-achse/ellipse_fit.py",line 52,in <module>
    main()
  File "C:/Users/X/PycharmProjects/lissajous-achse/ellipse_fit.py",line 49,in main
    print(str(find_vector_of_minor_axis_from_chunk(data)))
  File "C:/Users/X/PycharmProjects/lissajous-achse/ellipse_fit.py",line 25,in find_vector_of_minor_axis_from_chunk
    popt,p0=p0)
  File "C:\Users\X\PycharmProjects\lissajous-achse\venv\lib\site-packages\scipy\optimize\minpack.py",line 763,in curve_fit
    res = leastsq(func,p0,Dfun=jac,full_output=1,**kwargs)
  File "C:\Users\X\PycharmProjects\lissajous-achse\venv\lib\site-packages\scipy\optimize\minpack.py",line 392,in leastsq
    raise TypeError('Improper input: N=%s must not exceed M=%s' % (n,m))
TypeError: Improper input: N=7 must not exceed M=3

Process finished with exit code 1

我的代码是第二个答案here的改编。简单包装变量here即可解决导致错误消息的问题。

为什么问题没有在提到的第二个答案中浮出水面?我如何打包由几个3d向量和单个标量组成的变量来解决此问题?我如何传递我的t,它是一个常数,不应该对其进行优化?

解决方法

显然,关于参数字段的长度,python相当聪明,这取决于最初的猜测。所以我可以只传入一个变量,然后像这样在函数中将其拆分:

import numpy as np
from scipy.optimize import minimize


def find_vector_of_minor_axis_from_chunk(data):
    n = 20  # number of points
    guess_center_point = data.mean(1)
    guess_center_point = guess_center_point[np.newaxis,:].transpose()
    guess_a_phase = 0.0
    guess_b_phase = 0.0
    guess_a = 1.0
    guess_b = 1.0
    guess_a_axis_vector = np.array([[1.0],[0.0],[0.0]])
    guess_b_axis_vector = np.array([[0.0],[1.0],[0.0]])

    p0 = np.array([guess_center_point,guess_a,guess_a_axis_vector,guess_a_phase,guess_b,guess_b_axis_vector,guess_b_phase])

    def ellipse_func(x,data):
        center_point = x[0]
        a = x[1]
        a_axis_vector = x[2]
        a_phase = x[3]
        b = x[4]
        b_axis_vector = x[5]
        b_phase = x[6]
        t = np.linspace(0,2 * np.pi,n)
        error = center_point + a * a_axis_vector * np.sin(t * a_phase) + b * b_axis_vector * np.sin(t + b_phase) - data
        error_sum = np.sum(error**2)
        print(str(error_sum))
        return error_sum

    popt,pcov = minimize(ellipse_func,p0,args=(data))
    center_point,a,a_axis_vector,a_phase,b,b_axis_vector,b_phase = popt

    print(str(a_axis_vector,b_axis_vector))
    shorter_vector = a_axis_vector
    if np.abs(a_axis_vector) > np.aps(b_axis_vector):
        shorter_vector = b_axis_vector
    return shorter_vector


def main():
    data = np.array([[-4.62767933,-4.6275775,-4.62735346,-4.62719652,-4.62711625,-4.62717975,-4.62723845,-4.62722407,-4.62713901,-4.62708749,-4.62703238,-4.62689101,-4.62687185,-4.62694013,-4.62701082,-4.62700483,-4.62697488,-4.62686825,-4.62675683,-4.62675204],[-1.58625998,-1.58625039,-1.58619648,-1.58617611,-1.58620606,-1.5861833,-1.5861821,-1.58619169,-1.58615814,-1.58616893,-1.58613179,-1.58615934,-1.58611262,-1.58610782,-1.58614017,-1.58613059,-1.58612699,-1.58607428,-1.58610183],[-0.96714786,-0.96713827,-0.96715984,-0.96715145,-0.96716703,-0.96712869,-0.96716104,-0.96713228,-0.96719698,-0.9671838,-0.96717062,-0.96715744,-0.96707717,-0.96709275,-0.96706519,-0.96715026,-0.96711791,-0.96713588,-0.96714786]])

    print(str(find_vector_of_minor_axis_from_chunk(data)))

if __name__ == '__main__':
    main()

我还修复了向量中初始值的一些浮点数与整数错误。

但是现在我得到了另一个错误:

Traceback (most recent call last):
  File "C:/Users/X/PycharmProjects/lissajous-achse/ellipse_fit.py",line 61,in <module>
    main()
  File "C:/Users/X/PycharmProjects/lissajous-achse/ellipse_fit.py",line 58,in main
    print(str(find_vector_of_minor_axis_from_chunk(data)))
  File "C:/Users/X/PycharmProjects/lissajous-achse/ellipse_fit.py",line 34,in find_vector_of_minor_axis_from_chunk
    popt,args=(data))
  File "C:\Users\X\PycharmProjects\lissajous-achse\venv\lib\site-packages\scipy\optimize\_minimize.py",line 604,in minimize
    return _minimize_bfgs(fun,x0,args,jac,callback,**options)
  File "C:\Users\X\PycharmProjects\lissajous-achse\venv\lib\site-packages\scipy\optimize\optimize.py",line 1063,in _minimize_bfgs
    if isinf(rhok):  # this is patch for numpy
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我猜

具有多个元素的数组的真值是不明确的。 使用a.any()或a.all()

是一些内部错误,源于内部决策矩阵如何进行。我不知道我是怎么引起的以及如何修复它。当我弄清楚如何正确完成操作后,我将返回并编辑此答案。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...