【问题标题】:Python: Use general scipy.optimize.curve_fit functionPython:使用通用 scipy.optimize.curve_fit 函数
【发布时间】:2021-12-04 11:40:43
【问题描述】:

我想在 python 中曲线拟合一些数据。我的程序如下所示:

from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

def lin(x, a, b,c):
    return a*x+b

def exp(x, a, b, c):
    return a*np.exp(b*x)+c

def ln(x, a, b, c):
    return a*np.log(b+x)+c

x_dummy = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
y_dummy = np.array([9.2, 9.9, 10.0, 11.2, 10.2, 12.6, 10.0, 11.6, 12.2])



popt, _ = curve_fit(lin, x_dummy[:-2], y_dummy[:-1])

y_approx = lin(x_dummy, popt[0], popt[1], popt[2])

print(y_approx[-1])



print(popt)
print(mean_squared_error(y_dummy[:-1], y_approx[:-2]))


plt.plot(x_dummy[:-1], y_dummy, color='blue')
plt.plot(x_dummy, y_approx, color='green')
plt.show()

我现在的目标是一个通用函数,称为 fn,它可以有一些参数,例如从某种意义上说,作为字符串,调用

popt, _ = curve_fit(fn('lin' or 'exp' or 'ln'), x_dummy[:-2], y_dummy[:-1])

意思相同

popt, _ = curve_fit(lin or exp or ln, x_dummy[:-2], y_dummy[:-1])

背景:我想生成一些数组 = ['lin', 'exp', 'ln'] 并遍历所有三种可能的曲线拟合并计算再现平方误差的最小值。

【问题讨论】:

    标签: python curve-fitting scipy-optimize


    【解决方案1】:

    找到了一些方法,但也许是更简单的方法:

    from scipy.optimize import curve_fit
    import matplotlib.pyplot as plt
    from sklearn.metrics import mean_squared_error
    
    class FunctionCollector():
        def __init__(self):
            self.name = 'lin'
    
        def setFunc(self, name):
            self.name = name
    
        def lin(self, x, a, b, c):
            return a*x+b
    
        def exp(self, x, a, b, c):
            return a*np.exp(b*x)+c
    
        def ln(self, x, a, b, c):
            return a*np.log(b+x)+c
    
        def fn(self, x, a, b, c):
            if self.name == 'lin':
                return self.lin(x, a,b,c)
            elif self.name == 'exp':
                return self.exp(x,a,b,c)
            elif self.name == 'ln':
                return self.ln(x,a,b,c)
            return 0
    
    
    
    def l(x,a,b,c):
        return a * x + b
    x_dummy = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
    y_dummy = np.array([9.2, 9.9, 10.0, 11.2, 10.2, 12.6, 10.0, 11.6, 12.2])
    
    
    #noise = 5*np.random.normal(size=y_dummy.size)
    #y_dummy = y_dummy + noise
    
    f = FunctionCollector()
    
    popt, _ = curve_fit(f.fn, x_dummy[:-2], y_dummy[:-1])
    y_approx = f.fn(x_dummy, popt[0], popt[1], popt[2])
    
    print(y_approx[-1])
    
    
    
    print(popt)
    print(mean_squared_error(y_dummy[:-1], y_approx[:-2]))
    
    
    plt.plot(x_dummy[:-1], y_dummy, color='blue')
    plt.plot(x_dummy, y_approx, color='green')
    plt.show()
    

    【讨论】:

      猜你喜欢
      • 2018-11-22
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2014-03-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多