【Question title】: Calculating cross-validation manually gives a different result
【Posted】: 2021-01-08 12:34:03
【Question】:

Let's generate the data:

set.seed(42)
y <- rnorm(125)                         # response
x <- data.frame(runif(125), rexp(125))  # two predictors

I want to perform 2-fold cross-validation on it, so:

library(caret)
model <- train(y ~ .,
  data = cbind(y, x), method = "lm",
  trControl = trainControl(method = "cv", number = 2)
)
model 

Linear Regression 

125 samples
  2 predictor

No pre-processing
Resampling: Cross-Validated (2 fold) 
Summary of sample sizes: 63, 62 
Resampling results:

  RMSE      Rsquared     MAE      
  1.091108  0.002550859  0.8472947

Tuning parameter 'intercept' was held constant at a value of TRUE

I want to reproduce the RMSE value above by hand, to make sure I fully understand cross-validation.
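For reference, here is what that number is. As far as I understand caret's default summary (`defaultSummary()`, which calls `postResample()`), the RMSE under "Resampling results" is the average of the per-fold RMSEs, not a single RMSE pooled over all held-out predictions. Writing K for the number of folds, V_k for the rows held out in fold k, and \hat f^{(-k)} for the model fitted without those rows:

$$
\mathrm{RMSE}_{\mathrm{CV}} = \frac{1}{K}\sum_{k=1}^{K} \sqrt{\frac{1}{|V_k|}\sum_{i \in V_k}\bigl(y_i - \hat f^{(-k)}(x_i)\bigr)^2}
$$

With K = 2 this is simply the mean of two RMSEs, which is what the manual attempt below computes.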

What I have done so far

As I can see above, my sample is split into two folds of 63 and 62 observations.

# Train the first model on the first fold
model_1 <- lm(y[1:63] ~ ., data = x[1:63, ])
# Calculate RMSE for the first model
RMSE_1 <- RMSE(y[64:125], predict(model_1, newdata = x[64:125, ]))
# Train the second model on the second fold
model_2 <- lm(y[64:125] ~ ., data = x[64:125, ])
# Calculate RMSE for the second model
RMSE_2 <- RMSE(y[1:63], predict(model_1, newdata = x[1:63, ]))
mean(c(RMSE_1, RMSE_2))
 1.023411

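My question is: why do I get a different RMSE? The difference is too large to be explained by estimation error, so caret must be computing it some other way. Do you know what I am doing differently?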

【Question comments】:

    Tags: r regression linear-regression cross-validation manual-testing


    【Answer 1】:

    Your logic is correct, but you need to make two changes:

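    1. caret creates its own two folds for training. They are not 1:63 and 64:125; caret draws them at random, so they depend on the seed.
    2. There is a typo in RMSE_2: it should use model_2, not model_1.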
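To illustrate point 1, here is a minimal base-R sketch (a stand-in, not caret's own code) that assigns rows to two folds at random with `sample()` instead of using the contiguous blocks 1:63 and 64:125, then averages the two held-out RMSEs. caret builds its folds with `createFolds()`, which for a numeric outcome also stratifies on quantiles of y, so the folds here, and therefore the resulting value, will not equal caret's 1.091108.

```r
# Sketch: 2-fold CV with randomly assigned folds, base R only.
# Assumption: folds come from sample(), not caret::createFolds(),
# so they will differ from the folds caret used.
set.seed(42)
y <- rnorm(125)
x <- data.frame(runif(125), rexp(125))
d <- data.frame(y = y, x)  # response plus the two predictors

# Shuffle fold labels: 63 rows get label 1, 62 rows get label 2
fold_id <- sample(rep(1:2, length.out = nrow(d)))

rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))

fold_rmse <- sapply(1:2, function(k) {
  train_rows <- which(fold_id != k)  # fit on the other fold
  test_rows  <- which(fold_id == k)  # score on the held-out fold
  fit <- lm(y ~ ., data = d[train_rows, ])
  rmse(d$y[test_rows], predict(fit, newdata = d[test_rows, ]))
})

# Cross-validated RMSE = mean of the per-fold RMSEs
cv_rmse <- mean(fold_rmse)
print(cv_rmse)
```

The only difference from the question's code is how rows are assigned to folds; fitting, scoring and averaging are the same.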

    Here is the updated code:

    # caret stores the training-row indices of each fold in model$control$index
    # (found by trial and error)
    model$control$index
    f1 <- model$control$index[[1]]
    f2 <- model$control$index[[2]]
    
    # Redo your calculations using caret's fold indices, and fix the typo in RMSE_2.
    # With 2 folds, f1 and f2 are complements, so each model is scored on the other set.
    model_1 <- lm(y[f1] ~ ., data = x[f1, ])
    # Calculate RMSE for the first model
    RMSE_1 <- RMSE(y[f2], predict(model_1, newdata = x[f2, ]))
    # Train the second model on the second fold
    model_2 <- lm(y[f2] ~ ., data = x[f2, ])
    # Calculate RMSE for the second model
    RMSE_2 <- RMSE(y[f1], predict(model_2, newdata = x[f1, ]))
    
    # Now matches the RMSE reported by caret (model$results$RMSE)
    mean(c(RMSE_1, RMSE_2))
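The same idea extends to any number of folds. Below is a sketch of a small hypothetical helper, `cv_rmse_from_index()` (my own name, not a caret function), that takes a list of training-row index vectors in the same shape as `model$control$index`, fits `lm()` on each training set, scores it on the complementary rows, and averages the per-fold RMSEs. It assumes each held-out set is exactly the complement of its training rows, which is the case for k-fold CV. The usage part builds 5 folds with base R so the block runs without caret.

```r
# Hypothetical helper (not part of caret): manual k-fold CV RMSE for lm(),
# given a list of training-row index vectors such as model$control$index.
# Assumes each held-out set is the complement of its training rows.
cv_rmse_from_index <- function(y, x, index) {
  d <- data.frame(y = y, x)
  all_rows <- seq_len(nrow(d))
  per_fold <- vapply(index, function(train_rows) {
    test_rows <- setdiff(all_rows, train_rows)  # held-out rows
    fit <- lm(y ~ ., data = d[train_rows, ])
    pred <- predict(fit, newdata = d[test_rows, ])
    sqrt(mean((d$y[test_rows] - pred)^2))       # RMSE on this fold
  }, numeric(1))
  list(per_fold = per_fold, mean = mean(per_fold))
}

# Usage with base-R folds (no caret needed): 5 random folds of 25 rows
set.seed(1)
y <- rnorm(125)
x <- data.frame(runif(125), rexp(125))
held_out <- split(sample(seq_len(125)), rep(1:5, length.out = 125))
train_index <- lapply(held_out, function(h) setdiff(seq_len(125), h))

res <- cv_rmse_from_index(y, x, train_index)
print(res$per_fold)  # one RMSE per fold
print(res$mean)      # the cross-validated RMSE
```

With caret loaded and the question's `model`, `cv_rmse_from_index(y, x, model$control$index)$mean` should give the same value as `model$results$RMSE`, for the same reason the two-fold code above matches.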
    
