【问题标题】:Why does SciPy's curve_fit function care about the type of xdata?为什么 SciPy 的 curve_fit 函数关心 xdata 的类型?
【发布时间】:2018-12-17 22:18:07
【问题描述】:

我试图使用 SciPy 的 curve_fit 拟合一些数据,结果非常奇怪。所以我尝试了尝试并测试并发现了xdata类型的问题。当xdataint 类型时,结果会变得很奇怪。但这并不适用于所有功能f。我用最高 6 阶的多项式进行了测试。从 3 阶及更高阶开始,结果变得很奇怪。

小例子:

import numpy as np
from scipy.optimize import curve_fit

def poly4(x, a, b, c, d, e):
    return a*np.power(x,4) + b*np.power(x,3) + c*np.power(x,2) + d*x + e

x = np.linspace(0, 9.6, 2400)
y = poly4(x, 0.03, -0.68, 5.6, -22, 1351)

x1 = np.arange(0, 2400, 1, dtype=np.dtype('float'))
x2 = np.arange(0, 2400, 1, dtype=np.dtype('int'))

popt1,_ = curve_fit(poly4, x1, y)
popt2,_ = curve_fit(poly4, x2, y)

f1 = poly4(x1, *popt1)
f2 = poly4(x2, *popt2)

绘制这些值

import matplotlib.pyplot as plt
plt.plot(f1, label='f1, float range')
plt.plot(f2, label='f2, int range')
plt.legend()
plt.show()

给予

蓝线正是结果的样子。查看curve_fit 输出与

print(popt1)
print(popt2)

给予

[ 9.05733149e-12 -4.92513534e-08 9.73032914e-05 -9.17048770e-02 1.35100000e+03]

[ 3.52993170e-11 -1.52725549e-10 9.38577666e-06 -3.58806105e-02 1.34272489e+03]

为什么这些结果如此不同?嗯,很明显,因为xdata 的数据类型。但是curve_fit为什么要关心xdata的数据类型呢?我看不出这背后的原因,也没有找到任何关于它的文档。

编辑:在python 3.6.3scipy 0.19.1python 3.7.1scipy 1.1.0 上测试。两者都在 Windows 上。

【问题讨论】:

  • 似乎是版本问题。在 python 3.6.5 上,intfloat 类型都给出了相同的结果(这是您所期望的)。我使用的 scipy 版本是1.1.0
  • @Bazingaa 这很奇怪。查看我的编辑。
  • 我的系统上的结果相同,但 x1x2 的结果相同,方法是将 power 替换为 float_power
  • 2400**4 导致np.int 溢出,这就是为什么您会看到足够大的x 的锯齿状图案。事实上,对于 4 的幂,溢出早在x=216
  • @Brenlla 谢谢!这解释了它。

标签: python numpy matplotlib scipy


【解决方案1】:

关心x的类型的不是curve_fit,而是你的函数poly4。 Numpy 在其操作中保留数组的类型。由于取整数的 n 次方,很快就会遇到整数溢出,从而产生意想不到的结果。

例如查看 np.power(x,3) 的输出:

x = np.arange(0,2400,1, dtype=np.int32)
plt.plot(x,np.power(x,3))

【讨论】:

    【解决方案2】:

    您和无法重现您的问题的每个人都遇到的问题是np.dtype('int') 的大小在不同的平台上是不同的。如果您将 x1x2 的声明替换为:

    x1 = np.arange(0, 2400, 1, dtype=np.dtype('float'))
    x2 = np.arange(0, 2400, 1, dtype=np.int32)
    

    那么无论平台如何,您都可以始终如一地重现奇怪的输出:

    最初的问题是由于np.int32 太小而无法处理您正在计算的一些非常大的数字,并且中间计算的值溢出。所以结果:

    poly4(np.arange(2000, 2010, dtype=np.int32), 0.03, -0.68, 5.6, -22, 1351)
    # array([4.60917546e+08, 3.82703937e+08, 4.34772636e+08, 3.59427040e+08,
       4.14366625e+08, 3.41894792e+08, 3.99711018e+08, 3.30118704e+08,
       3.90817330e+08, 3.24110298e+08])
    

    看起来与以下结果非常不同:

    poly4(np.arange(2000, 2010, dtype=np.int64), 0.03, -0.68, 5.6, -22, 1351)
    # array([4.74582357e+11, 4.75534936e+11, 4.76488948e+11, 4.77444394e+11,
       4.78401277e+11, 4.79359597e+11, 4.80319357e+11, 4.81280557e+11,
       4.82243198e+11, 4.83207283e+11])
    

    【讨论】:

      猜你喜欢
      • 2016-04-24
      • 2021-02-16
      • 2016-03-08
      • 2021-03-19
      • 1970-01-01
      • 1970-01-01
      • 2021-02-16
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多