【问题标题】:Comparing statsmodel predictions with actual y-values (indexing issue)将 statsmodel 预测与实际 y 值进行比较(索引问题)
【发布时间】:2017-10-10 21:54:43
【问题描述】:

目标:我想计算拟合多元线性回归模型所做预测的检验误差。

问题:这是我的代码。它旨在将线性回归模型拟合到训练数据,然后根据 X_test 变量预测 y 值(价格):

X.insert(0, 'constant', 1)   # insert constant column
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)
lm_sm = sm.OLS(y_train, X_train).fit()

y_pred = pd.DataFrame()   # dataframe for predictions vs actual y-values
y_pred['predictions'] = lm_sm.predict(X_test)

print y_test.sort_index().head()
print y_pred.sort_index().head()

代码输出如下:

       price
6   257500.0
17  485000.0
23  252700.0
25  233000.0
26  937000.0
     predictions
0  509428.615367
1  324403.584917
2  477385.431339
3  484962.235105
4  827039.820936

比较预测和实际价格,这显然是不对的。 predict() 方法不保留我的训练/测试拆分中的索引。因此,当我将预测价格与实际价格进行比较时,我无法确定我比较的是正确的值。

我想到的唯一解决方案(我不确定这是否正确)是在进行预测时对 X_test 进行排序,即y_pred['predictions'] = lm_sm.predict(X_test.sort_index())。预测看起来更符合我的预期(注意这是第一个线性回归/基准测试,因此尚未应用特征工程):

       price
6   257500.0
12  310000.0
18  189000.0
25  233000.0
29  719000.0
     predictions
0  259985.788272
1  590648.478023
2  339621.126287
3  316402.199424
4  635513.611634

然后,我将根据这些排序的数据帧执行测试错误计算。这是正确的吗?有没有更清洁的方法来做到这一点?一个我不知道的方法?任何帮助/想法将不胜感激,谢谢!

【问题讨论】:

    标签: python-2.7 pandas statsmodels


    【解决方案1】:

    我实际上不认为订购有什么问题。与y_pred 的干净索引相比,y_test 的混合索引是混淆的根源。

    当您在数据集 (X, y) 上使用 train_test_split 时,它显然会打乱,这就是为什么 y_test 有一个打乱的索引子集。

    当您执行lm_sm.predict(X_test) 时,输出是一个普通的 numpy 数组。它不是带有索引的 pandas 对象,因此索引信息已经丢失。另外,在您的代码中,您将结果存储在一个新的 y_pred 数据框中。在那个新的y_pred 中,索引将是一个新的自动增量:0、1、2 ...

    所以不要使用sort_index(),您可以确定y_pred 将与X_testy_test 对齐。

    编辑。希望这能更好地说明问题。

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from statsmodels.regression.linear_model import OLS
    
    X = pd.DataFrame(np.random.random((60, 3)))
    y = pd.DataFrame(np.random.random((60, 1)))
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)
    
    print(X_test)
    

    测试集的特点:

               0         1         2
    2   0.547993  0.479149  0.495539
    48  0.332964  0.857635  0.501391
    23  0.380500  0.377257  0.088766
    35  0.045725  0.432096  0.239523
    52  0.254861  0.207215  0.985722
    37  0.099525  0.205250  0.054000
    22  0.426227  0.253524  0.336110
    43  0.716443  0.006443  0.423447
    49  0.146820  0.803366  0.390921
    6   0.127666  0.848561  0.936604
    46  0.303034  0.548064  0.852688
    33  0.516726  0.977396  0.829725
    

    测试集的目标是print(y_test)

               0
    2   0.123253
    48  0.494307
    23  0.312021
    35  0.939558
    52  0.958955
    37  0.681215
    22  0.181427
    43  0.907552
    49  0.589316
    6   0.613305
    46  0.947220
    33  0.696609
    

    指数是随机排列的,但它们是一致的。现在做:

    lm_sm = OLS(y_train, X_train).fit()
    y_pred = pd.DataFrame()   # dataframe for predictions vs actual y-values
    y_pred['predictions'] = lm_sm.predict(X_test)
    
    # Print this directly
    print(lm_sm.predict(X_test))
    

    最后一行只是一个普通的 numpy 数组:

    [ 0.44549342  0.44973765  0.24465328  0.17840542  0.42329909  0.09567253
      0.30675321  0.38496281  0.33836597  0.49959203  0.47488055  0.63751567]
    

    当您查看新的 y_pred 数据框时,索引是新的 0、1、2 ... print(y_pred)

        predictions
    0      0.445493
    1      0.449738
    2      0.244653
    3      0.178405
    4      0.423299
    5      0.095673
    6      0.306753
    7      0.384963
    8      0.338366
    9      0.499592
    10     0.474881
    11     0.637516
    

    您可能对这些索引与 y_test 的索引不匹配感到惊讶,但正如我所展示的,predict() 函数返回一个普通的 numpy 数组,并且没有任何东西将生成的 y_pred 连接到原始索引了。不过,您可以确定一切都是对齐的。

    【讨论】:

    • predict 在最新版本的 statsmodels 中,如果提供的 exog 具有索引,则应返回具有适当索引的 pandas Series。在过渡期间的特殊情况下存在一些错误,但预测返回的索引应该是 statsmodels 0.9 中正确索引的 pandas.Series。
    【解决方案2】:

    请注意,train_test_split 现在有可选参数shuffle=False 以避免洗牌。这应该可以解决您在 df 中对齐的问题。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-07-27
      • 1970-01-01
      • 2019-06-23
      • 1970-01-01
      • 2021-03-19
      • 1970-01-01
      • 2020-09-19
      • 1970-01-01
      相关资源
      最近更新 更多