【问题标题】:How to calculate error in Polynomial Linear Regression?如何计算多项式线性回归中的误差?
【发布时间】:2017-10-10 16:41:06
【问题描述】:

我正在尝试计算我正在使用的训练数据的错误率。

我相信我计算的错误不正确。公式如图:

y计算如下:

我在49 行的函数fitPoly(M) 中计算这个。我相信我计算错误y(x(n)),但我不知道还能做什么。


以下是最小、完整且可验证的示例。

import numpy as np
import matplotlib.pyplot as plt

dataTrain = [[2.362761180904257019e-01, -4.108125266714775847e+00],
[4.324296163702689988e-01,  -9.869308732049049127e+00],
[6.023323504115264404e-01,  -6.684279243433971729e+00],
[3.305079685397107614e-01,  -7.897042003779912278e+00],
[9.952423271981121200e-01,  3.710086310489402628e+00],
[8.308127402955634011e-02,  1.828266768673480147e+00],
[1.855495407116576345e-01,  1.039713135916495501e+00],
[7.088332047815845138e-01,  -9.783208407540947560e-01],
[9.475723071629885697e-01,  1.137746192425550085e+01],
[2.343475721257285427e-01,  3.098019704040922750e+00],
[9.338350584099475160e-02,  2.316408265530458976e+00],
[2.107903139601833287e-01,  -1.550451474833406396e+00],
[9.509966727520677843e-01,  9.295029459100994984e+00],
[7.164931165416982273e-01,  1.041025972594300075e+00],
[2.965557300301902011e-03,  -1.060607693351102121e+01]]

def strip(L, xt):
    ret = []
    for i in L:
        ret.append(i[xt])
    return ret

x1 = strip(dataTrain, 0)
y1 = strip(dataTrain, 1)

# HELP HERE

def getY(m, w, D):
    y = w[0]
    y += np.sum(w[1:] * D[:m])
    return y

# HELP ABOVE

def dataMatrix(X, M):
    Z = []
    for x in range(len(X)):
        row = []
        for m in range(M + 1):
            row.append(X[x][0] ** m)
        Z.append(row)
    return Z

def fitPoly(M):
    t = []
    for i in dataTrain:
        t.append(i[1])
    w, _, _, _ = np.linalg.lstsq(dataMatrix(dataTrain, M), t)
    w = w[::-1]
    errTrain = np.sum(np.subtract(t, getY(M, w, x1)) ** 2)/len(x1)
    print('errTrain: %s' % (errTrain))
    return([w, errTrain])

#fitPoly(8)

def plotPoly(w):
    plt.ylim(-15, 15)
    x, y = zip(*dataTrain)
    plt.plot(x, y, 'bo')
    xw = np.arange(0, 1, .001)
    yw = np.polyval(w, xw)
    plt.plot(xw, yw, 'r')

#plotPoly(fitPoly(3)[0])

def bestPoly():
    m = 0
    plt.figure(1)
    plt.xlim(0, 16)
    plt.ylim(0, 250)
    plt.xlabel('M')
    plt.ylabel('Error')
    plt.suptitle('Question 3: training and Test error')
    while m < 16:
        plt.figure(0)
        plt.subplot(4, 4, m + 1)
        plotPoly(fitPoly(m)[0])
        plt.figure(1)
        plt.plot(fitPoly(m)[1])
        #plt.plot(fitPoly(m)[2])
        m+= 1
    plt.figure(3)
    plt.xlabel('t')
    plt.ylabel('x')
    plt.suptitle('Question 3: best-fitting polynomial (degree = 8)')
    plotPoly(fitPoly(8)[0])
    print('Best M: %d\nBest w: %s\nTraining error: %s' % (8, fitPoly(8)[0], fitPoly(8)[1], ))

bestPoly()

【问题讨论】:

  • 为什么你认为你计算的误差不正确?
  • @KevinK。 “你应该发现测试误差通常(如果不是总是)大于训练误差。此外,训练误差应该随着 M 的增加而减小,当 M = 15 时达到零。测试误差应该在最初趋于减小然后开始增加,当 M 接近 15 时变得非常大。”根据讲义,最大误差也应低于 250。
  • 问候@AndrewRaleigh。 np.sum(w[1:] * D[:m]) 应该工作吗? getY 的作用是什么?
  • 我认为getY 应该写成更像y = w[0] + np.sum([w[i+1]*D**i for i in range(m)])D 是单个值而不是数组。
  • @AndrewRaleigh 首先,我认为w[1:] * D[:m] 会出错。所以你希望getYD(训练数据)作为输入域输出y(x)的所有值......?我假设D 将是第一列中的值,t(n) 将是dataTrain 的第二列中的值..?

标签: python numpy linear-regression polynomial-math


【解决方案1】:

更新:此解决方案使用 numpy 的 np.interp 将点连接为一种“最合适的”。然后,我们使用您的误差函数来查找此插值线与每个多项式次数的预测 y 值之间的差异。

import numpy as np
import matplotlib.pyplot as plt
import itertools

dataTrain = [
  [2.362761180904257019e-01, -4.108125266714775847e+00],
  [4.324296163702689988e-01,  -9.869308732049049127e+00],
  [6.023323504115264404e-01,  -6.684279243433971729e+00],
  [3.305079685397107614e-01,  -7.897042003779912278e+00],
  [9.952423271981121200e-01,  3.710086310489402628e+00],
  [8.308127402955634011e-02,  1.828266768673480147e+00],
  [1.855495407116576345e-01,  1.039713135916495501e+00],
  [7.088332047815845138e-01,  -9.783208407540947560e-01],
  [9.475723071629885697e-01,  1.137746192425550085e+01],
  [2.343475721257285427e-01,  3.098019704040922750e+00],
  [9.338350584099475160e-02,  2.316408265530458976e+00],
  [2.107903139601833287e-01,  -1.550451474833406396e+00],
  [9.509966727520677843e-01,  9.295029459100994984e+00],
  [7.164931165416982273e-01,  1.041025972594300075e+00],
  [2.965557300301902011e-03,  -1.060607693351102121e+01]
  ]

data = np.array(dataTrain)
data = data[data[:, 0].argsort()]

X,y = data[:, 0], data[:, 1]

fig,ax = plt.subplots(4, 4)
indices = list(itertools.product([0,1,2,3], repeat=2))
for i,loc in enumerate(indices, start=1):
  xx = np.linspace(X.min(), X.max(), 1000)
  yy = np.interp(xx, X, y)
  w = np.polyfit(X, y, i)
  y_pred = np.polyval(w, xx)
  ax[loc].scatter(X, y)
  ax[loc].plot(xx, y_pred)
  ax[loc].plot(xx, yy, 'r--')

  error = np.square(yy - y_pred).sum() / X.shape[0]
  print(error)

plt.show()

打印出来:

2092.19807848
1043.9400277
1166.94550318
252.238810889
225.798905379
155.785478366
125.662973726
143.787869281
6553.66570273
10805.6609259
15577.8686283
13536.1755299
108074.871771
213513916823.0
472673224393.0
1.01198058355e+12

从视觉上看,它描绘了这一点:

从这里开始,只需将这些错误保存到列表中并找到最小值。

【讨论】:

  • 感谢@Jarad,您的编辑是正确的,因为当 M > 8 时,尖峰会变得更加明显。尽管如此,您的代码很棒,有没有办法使用 np.linalg.lstsq 而不是 polyfit /polyval?
  • 我更新了我的答案。我不认为你可以用np.linalg.lstsq 做到这一点,因为这不是一个 (M x N) 矩阵;它是一个 X 值。但我可能是错的。
  • 您可以使用函数 dataMatrix 创建 (M x N) 矩阵。我想出了如何使用linalg来做到这一点。如果你不介意,还有一件事。误差不应该大于 250,虽然数字看起来是对的,但它们似乎相当大?
  • 我认为这些数字要大得多,因为我在 xx 变量中人为地制作了 1000 个 X 数据点,以绘制更平滑的线条并显示波浪。使用简单的 X 和 y 坐标不足以可视化真正的多项式线。因此,在我看来,输出的error 仍然与数据相关,只是有更多的数据点,这在计算差异时会增加更多的误差。要尝试的一件事可能是将这个数字从 1000 更改为 100 或 1000000,看看它是否给出了不同的“最佳”。
  • 感谢您的回复。尽管由于polyfit 的顺序太高,我收到了很多警告。
【解决方案2】:

我可以贡献:

  def pol_y(x, w):
        y = 0; power = 0;
        for i in w:
            y += i*(x**power);
            power += 1;
        return y

M 被隐式包含,因为它是w 的最终索引。所以如果w = [0, 0, 1],那么pol_y(x, w)就等于f(x) = x^2。

如果要映射dataTrain 的第一列:

get_Y = [pol_y(i, w) for i in x1 ] 

误差可以通过以下方式计算

vec_error = [(y1[i] - getY[i])**2 for i in range(0, len(y1)];
train_error = np.sum(vec_error)/len(y1);

希望这会有所帮助。

【讨论】:

  • 我认为您的pol_yy(x) 的计算方式,谢谢!不过对y(x(n)) 有什么想法吗?尝试您的 get_Y 导致的错误幅度比我预期的要大,所以我不完全确定就是这样。
  • 我添加了一些东西。所以向量/列表w 将是最佳拟合多项式的系数值,是吗?
猜你喜欢
  • 2020-08-09
  • 1970-01-01
  • 2019-12-05
  • 1970-01-01
  • 2020-09-14
  • 1970-01-01
  • 2014-02-20
  • 2021-01-19
  • 2021-11-11
相关资源
最近更新 更多