【问题标题】:How can I use curve_fit for functions that involve case-splitting?如何将 curve_fit 用于涉及大小写拆分的函数?
【发布时间】:2022-12-17 11:42:42
【问题描述】:

我想将 curve_fit 用于涉及大小写拆分的函数。
但是 python 抛出错误。

curve_fit不支持这样的功能吗?还是函数定义有问题?

例子)

from scipy.optimize import curve_fit
import numpy as np

def slope_devided_by_cases(x,a,b):
    if x < 4:
        return a*x + b
    else:
        return 4*a + b

data_x =  [1,2,3,4,5,6,7,8,9]  # x
data_y  = [45,46,42,36,27,23,21,13,11]  # y
coef, cov = curve_fit(slope_devided_by_cases, data_x, data_y)

错误)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
C:\Users\Lisa~1\AppData\Local\Temp/ipykernel_1516/1012358816.py in <module>
     10 data_x =  [1,2,3,4,5,6,7,8,9]  # x
     11 data_y  = [45,46,42,36,27,23,21,13,11]  # y
---> 12 coef, cov = curve_fit(slope_devided_by_cases, data_x, data_y)

~\anaconda3\lib\site-packages\scipy\optimize\minpack.py in curve_fit(f, xdata, ydata, p0, sigma, absolute_sigma, check_finite, bounds, method, jac, **kwargs)
    787         # Remove full_output from kwargs, otherwise we're passing it in twice.
    788         return_full = kwargs.pop('full_output', False)
--> 789         res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)
    790         popt, pcov, infodict, errmsg, ier = res
    791         ysize = len(infodict['fvec'])

~\anaconda3\lib\site-packages\scipy\optimize\minpack.py in leastsq(func, x0, args, Dfun, full_output, col_deriv, ftol, xtol, gtol, maxfev, epsfcn, factor, diag)
    408     if not isinstance(args, tuple):
    409         args = (args,)
--> 410     shape, dtype = _check_func('leastsq', 'func', func, x0, args, n)
    411     m = shape[0]
    412 

~\anaconda3\lib\site-packages\scipy\optimize\minpack.py in _check_func(checker, argname, thefunc, x0, args, numinputs, output_shape)
     22 def _check_func(checker, argname, thefunc, x0, args, numinputs,
     23                 output_shape=None):
---> 24     res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
     25     if (output_shape is not None) and (shape(res) != output_shape):
     26         if (output_shape[0] != 1):

~\anaconda3\lib\site-packages\scipy\optimize\minpack.py in func_wrapped(params)
    483     if transform is None:
    484         def func_wrapped(params):
--> 485             return func(xdata, *params) - ydata
    486     elif transform.ndim == 1:
    487         def func_wrapped(params):

C:\Users\Lisa~1\AppData\Local\Temp/ipykernel_1516/1012358816.py in slope_devided_by_cases(x, a, b)
      3 
      4 def slope_devided_by_cases(x,a,b):
----> 5     if x < 4:
      6         return a*x + b
      7     else:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我想将 curve_fit 用于涉及大小写拆分的函数,例如上面的示例。

【问题讨论】:

  • x 是一个类似于data_x 的数组时,您期望if x&lt;a: 会发生什么。 fit 将像 slope_devided_by_cases(data_x,1,1) 一样调用您的函数,然后尝试将该结果与 data_y 进行比较。看看data_x&lt;1。这对你来说代表着什么?

标签: python-3.x scipy curve-fitting


【解决方案1】:

问题是 x &lt; 4 不是布尔标量值,因为 curve_fit 将使用 np.ndarray x (您给定的 x 数据点)而不是标量值来评估您的函数。因此,x &lt; 4 将为您提供一个布尔值数组。

也就是说,您可以使用 NumPy 的矢量化操作重写您的函数:

def slope_devided_by_cases(x,a,b):
    return (x < 4) * (a*x + b) + (x >= 4) * (4*a+b)

或者,您可以使用 np.where 作为 if-else 方法的矢量化替代方法:

def slope_devided_by_cases(x,a,b):
    return np.where(x < 4, a*x + b, 4+a+b)

【讨论】:

    猜你喜欢
    • 2016-03-04
    • 1970-01-01
    • 2023-02-22
    • 2013-11-06
    • 1970-01-01
    • 1970-01-01
    • 2010-12-09
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多