【发布时间】:2018-06-17 05:58:53
【问题描述】:
我需要拟合scipy.optimize.curve_fit 一些看起来像图中点的数据。我使用一个函数y(x)(参见下面的定义),它为x<x0 提供一个常数y(x)=c,否则为多项式(例如第二条倾斜线y1 = mx+q)。
我对参数(x0, c, m, q)给出了一个合理的初步猜测,如图所示。拟合结果表明,除第一个参数x0 外,所有参数均已优化。
为什么会这样?
是我如何定义函数testfit(x, *p),其中x0 (=p[0]) 出现在另一个函数中?
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
# generate some data:
x = np.linspace(0,100,1000)
y1 = np.repeat(0, 500)
y2 = x[500:] - 50
y = np.concatenate((y1,y2))
y = y + np.random.randn(len(y))
def testfit(x, *p):
''' piecewise function used to fit
it's a constant (=p[1]) for x < p[0]
or a polynomial for x > p[0]
'''
x = x.astype(float)
y = np.piecewise(x, [x < p[0], x >= p[0]], [p[1], lambda x: np.poly1d(p[2:])(x)])
return y
# initial guess, one horizontal and one tilted line:
p0_guess = (30, 5, 0.3, -10)
popt, pcov = curve_fit(testfit, x, y, p0=p0_guess)
print('params guessed : '+str(p0_guess))
print('params from fit : '+str(popt))
plt.plot(x,y, '.')
plt.plot(x, testfit(x, *p0_guess), label='initial guess')
plt.plot(x, testfit(x, *popt), label='final fit')
plt.legend()
输出
params guessed : (30, 5, 0.3, -10)
params from fit : [ 30. 0.04970411 0.80106256 -34.17194401]
OptimizeWarning: Covariance of the parameters could not be estimated category=OptimizeWarning)
【问题讨论】:
-
我相信您面临的问题本质上与this question 相似。您有两种选择:1. 使用不同的拟合算法。 2. 尽量避免曲线中的垂直边缘,方法是 a 给它一个小斜率,或者 b 让多项式与常数平滑连接值)。
-
正如@kazemakase 建议的那样,
curve_fit不会很好地处理像您的 p[0] 这样的离散变量,因为它无法为该变量创建偏导数。
标签: python scipy curve-fitting