【问题标题】:simultaneous fitting python parameter sharing同时拟合python参数共享
【发布时间】:2016-12-06 20:19:41
【问题描述】:

我有六个数据集,我希望同时拟合所有六个数据集,六个数据集之间共有两个参数,一个单独拟合。

我打算为数据集拟合一个简单的 ax**2+bx+c 多项式,其中 a 和 b 在六个数据集之间共享,而偏移量 c 在六个数据集之间不共享。

因此,我在数据集之间拟合了一个公共斜率,但偏移量可变。

我完全有能力单独拟合它们,但是由于每个数据集之间的斜率相似,因此使用同时拟合将大大改善偏移量 c 的误差。

我通常使用 scipy.optmize.curve_fit 进行拟合。

import numpy as np
from scipy.optimize import curve_fit

def func(x,a,b,c):
    return (a*(x**2)+b*x+c)
def fit(x,y,yerr):
    popt, pcov = curve_fit(func,x,y,p0=[-0.6,5,-12],sigma=yerr)
    chi=np.sum( ((func(x, *popt) - y) / yerr)**2)
    redchi=(chi-1)/len(y)
    return popt,pcov,redchi,len(y)

我正在处理 6 组:x,xerr,y,yerr 每个集合的 len(x) 和 len(y) 都不同。

我知道我必须连接数据集并以这种方式拟合它们。

如果有人可以提供任何建议或帮助,我相信这对我和社区都有好处。

【问题讨论】:

  • 是 Python 问题还是数学问题?
  • 你是说你将有六个值c(比如c_1...c_6)和ab各一个值,这样dataset_i 的模型将是 a*x**2 + b*x + c_i?
  • @inspectorG4dget 是的,a 和 b 在所有 6 个数据集中的拟合值相同,c 将有 6 个拟合值。
  • @Laurent LAPORTE,目前主要是python问题。有没有办法适应共享参数?你能把数据集放在一个矩阵中同时拟合它们吗?
  • 我不知道如何用scipy.curve_fit 做到这一点,但我可以编写一个简单的双染色体遗传算法来做到这一点

标签: python curve-fitting least-squares data-fitting


【解决方案1】:

因为我有类似的拟合问题,所以我做了symfit来处理这种情况。所以我很抱歉无耻地建议我自己的包,但我认为这对你很有帮助。它包含曲线拟合,但提供了一个符号界面,使事情变得更容易。

您的问题可以这样解决:

from symfit import variables, parameters, Fit

xs = variables('x_1, x_2, x_3, x_4, x_5, x_6')
ys = variables('y_1, y_2, y_3, y_4, y_5, y_6')

a, b = parameters('a, b')
cs = parameters(', '.join('c_{}'.format(i) for i in range(1, 6)))

model_dict = {
    y: a * x**2 + b * x + c
        for x, y, c in zip(xs, ys, cs)
}

fit = Fit(model_dict, x_1=x1_data, x_2=x2_data, ..., y_1=y1_data, ..., sigma_y_1=y1_err, sigma_y_2=y2_err, ...)
fit_result = fit.execute()
print(fit_result)

查看文档了解更多信息: http://symfit.readthedocs.io/en/latest/fitting_types.html#global-fitting

附言为了对您的参数进行初始猜测,每个Parameter 对象都带有一个.value 属性,该属性包含初始猜测。例如,a.value = -0.6

编辑: 以前需要一些额外的解决方法,这解释了下面的一些讨论。不过,我现在发布了一个新的symfit 版本,上面的代码在其中运行。

【讨论】:

  • 我检查了你的python包,看起来真的很整洁!我会看看我将来是否可以使用它:-)!对于“Fit”函数:我给它来自 x1,x2,x3.. y1,y2,y3.. 的数据集。但是,似乎每个数组的大小需要相同,否则我得到经典的“操作数”无法与形状 (25,) (46,) 一起广播”。配件的形状是否需要相同?谢谢。
  • 很高兴你喜欢它!如果我正确地完成了我的工作,它们不应该是相同的大小。您确定您的x_iy_isigma_y_i 具有相同的形状吗?在模型的每个组件中,它们必须具有相同的形状。如果是,您可以尝试像这样手动调用模型:fit.model([your data here], a=-0.6, b=5, c_1=12, etc.)。这应该单独评估模型的每个组件,并且这个调用应该有效。 symfit.readthedocs.io/en/latest/tutorial.html#named-models
  • 我正在使用输入:
  • 这是输入:'fit = Fit(model_dict, x_1=x[0], x_2=x[1],x_3=x[2], x_4=x[3],x_5= x[4], x_6=x[5], y_1=y[0], y_2=y[1], y_3=y[2], y_4=y[3], y_5=y[4], y_6=y [5], sigma_y_1=yerr[0], sigma_y_2=yerr[1],sigma_y_3=yerr[2],sigma_y_4=yerr[3],sigma_y_5=yerr[4],sigma_y_6=yerr[5]) fit_result = fit. execute() print(fit_result)' 使得 'x[i],y[i],yerr[i]' 都是相同的形状。
  • 在玩了一会儿之后,我收回了我的声明。在当前版本的symfit 中,它们的长度必须相同。因此,我为您实施了一个快速解决方法,并将更改我的答案以将其提供给您。在清理和概括之后,我也会把它放在新版本的symfit 中。感谢您指出这一点并帮助我改进这个包:)。
【解决方案2】:

感谢所有建议,我似乎找到了一种方法,可以同时使用 a,b 和 c1,c2,c3,c4,c5,c6 作为参数,其中 a 和 b 是共享的。

下面是我最后使用的代码:

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit

x=[vt,bt,ut,w1t,m2t,w2t]
y=[vmag,bmag,umag,w1mag,m2mag,w2mag]
xerr=[vterr,uterr,bterr,w1terr,m2terr,w2terr]
yerr=[vmagerr,umagerr,bmagerr,w1magerr,m2magerr,w2magerr]

def poly(x_, a, b, c1, c2, c3, c4, c5, c6):
    #all this is just to split x_data into the original parts of x
    l= len(x[0])
    l1= len(x[1])
    l2= len(x[2])
    l3= len(x[3])
    l4= len(x[4])
    l5= len(x[5])
    s=l+l1
    s1=l2+s
    s2=l3+s1
    s3=l4+s2
    s4=l5+s3


    a= np.hstack([
a*x_[:l]**2 + b*x_[:l] +c1,
a*x_[l:(s)]**2 + b*x_[l:(s)] +c2,
a*x_[(s):(s1)]**2 + b*x_[(s):(s1)] +c3,
a*x_[(s1):(s2)]**2 + b*x_[(s1):(s2)] +c4,
a*x_[(s2):(s3)]**2 + b*x_[(s2):(s3)] +c5,
a*x_[(s3):(s4)]**2 + b*x_[(s3):(s4)] +c6
])       
    print a
    return a 
x_data = np.hstack([x[0],x[1],x[2],x[3],x[4],x[5]])
y_data = np.hstack([y[0],y[1],y[2],y[3],y[4],y[5]])

(a, b, c1, c2, c3, c4, c5, c6), _ = curve_fit(poly, x_data, y_data)

抱歉,如果这是糟糕的编码!我的方法很粗暴!但是,它确实做得很好!

下面是我的结果。

Fitted results from simultaneous fitting with shared parameters

【讨论】:

    【解决方案3】:

    一种可能性是更改要拟合的函数,使每个数据集都有自己的“a”和“b”参数,并带有一个共同的“c”,类似于这个粗略的代码 sn-p:

    def func(x,a1,b1,a2,b2,a3,b3,a4,b4,a5,b5,a6,b6, c):
        if x in data_set_1:
            return (a1*(x**2)+b1*x+c)
        if x in data_set_2:
            return (a2*(x**2)+b2*x+c)
        if x in data_set_3:
            return (a3*(x**2)+b3*x+c)
        if x in data_set_4:
            return (a4*(x**2)+b4*x+c)
        if x in data_set_5:
            return (a5*(x**2)+b5*x+c)
        if x in data_set_6:
            return (a6*(x**2)+b6*x+c)
        raise Exception('Data outside fitting range') # just in case
    

    【讨论】:

    • 你得到了共享参数和个人参数。
    猜你喜欢
    • 1970-01-01
    • 2021-03-22
    • 1970-01-01
    • 1970-01-01
    • 2013-12-18
    • 2019-12-04
    • 2021-05-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多