【问题标题】:curve_fit doesn't work properly with 4 parameters使用 4 个参数时,curve_fit 无法正常工作
【发布时间】:2017-02-04 13:49:00
【问题描述】:

运行以下代码,

x = np.array([50.849937, 53.849937, 56.849937, 59.849937, 62.849937, 65.849937, 68.849937, 71.849937, 74.849937, 77.849937, 80.849937, 83.849937, 86.849937, 89.849937, 92.849937])
y = np.array([410.67800, 402.63800, 402.63800, 386.55800, 330.27600, 217.71400, 72.98990, 16.70860, 8.66833, 40.82920, 241.83400, 386.55800, 394.59800, 394.59800, 402.63800])
def f(om, a, i , c):
       return a - i*np.exp(- c* (om-74.)**2)
par, cov = curve_fit(f, x, y)
stdev = np.sqrt(np.diag(cov) )

生成此图表,

具有以下参数和标准差:

par =   [ 4.09652163e+02, 4.33961227e+02, 1.58719772e-02]
stdev = [ 1.46309578e+01, 2.44878171e+01, 2.40474753e-03]

但是,当尝试将此数据拟合到以下函数时:

def f(om, a, i , c, omo):
       return a - i*np.exp(- c* (om-omo)**2)

它不起作用,它会产生一个标准偏差

stdev = [inf, inf, inf, inf, inf]

有没有办法解决这个问题?

【问题讨论】:

    标签: python numpy matplotlib scipy curve-fitting


    【解决方案1】:

    看起来它没有收敛(参见thisthis)。尝试添加一个初始条件,

    par, cov = curve_fit(f, x, y, p0=[1.,1.,1.,74.])
    

    导致

    par = [ 4.11892318e+02, 4.36953868e+02, 1.55741131e-02, 7.32560690e+01])
    stdev = [ 1.17579445e+01, 1.94401006e+01, 1.86709423e-03, 2.62952690e-01]
    

    【讨论】:

    • 必须设置所有条件:'startfirst = [86.5, -438., -411., 73.5]',不能只使用一个,但谢谢
    【解决方案2】:

    您可以根据数据计算初始条件:

    %matplotlib inline
    import pylab as pl
    
    import numpy as np
    from scipy.optimize import curve_fit
    
    x = np.array([50.849937, 53.849937, 56.849937, 59.849937, 62.849937, 65.849937, 68.849937, 71.849937, 74.849937, 77.849937, 80.849937, 83.849937, 86.849937, 89.849937, 92.849937])
    y = np.array([410.67800, 402.63800, 402.63800, 386.55800, 330.27600, 217.71400, 72.98990, 16.70860, 8.66833, 40.82920, 241.83400, 386.55800, 394.59800, 394.59800, 402.63800])
    
    def f(om, a, i , c, omo):
           return a - i*np.exp(- c* (om-omo)**2)
    
    par, cov = curve_fit(f, x, y, p0=[y.max(), y.ptp(), 1, x[np.argmin(y)]])
    stdev = np.sqrt(np.diag(cov) )
    
    pl.plot(x, y, "o")
    x2 = np.linspace(x.min(), x.max(), 100)
    pl.plot(x2, f(x2, *par))
    

    【讨论】:

      猜你喜欢
      • 2018-03-25
      • 2020-12-22
      • 1970-01-01
      • 1970-01-01
      • 2014-02-22
      • 1970-01-01
      • 2013-03-06
      • 2018-04-24
      • 1970-01-01
      相关资源
      最近更新 更多