【问题标题】:Prediction on new data with GLMNET and CARET - The number of variables in newx must be X使用 GLMNET 和 CARET 预测新数据 - newx 中的变量数必须为 X
【发布时间】:2021-12-25 15:26:51
【问题描述】:

我有一个数据集,我正在使用它进行 k 折交叉验证。

在每一折中,我都将数据拆分为训练和测试数据集。

对于数据集 X 的训练,我运行以下代码:

cv_glmnet <- caret::train(x = as.data.frame(X[curtrainfoldi, ]), y = y[curtrainfoldi, ],
                       method = "glmnet",
                       preProcess = NULL,
                       trControl = trainControl(method = "cv", number = 10),
                       tuneLength = 10)
    
   

查看'cv_glmnet'的类,返回'train'。

然后我想使用这个模型来预测测试数据集中的值,这是一个具有相同数量的变量(列)的矩阵

# predicting on test data 
yhat <- predict.train(cv_glmnet, newdata = X[curtestfoldi, ])   

但是,我一直遇到以下错误:

Error in predict.glmnet(modelFit, newdata, s = modelFit$lambdaOpt, type = "response") : 
  The number of variables in newx must be 210

我在 caret.predict 文档中注意到,它声明如下:

newdata 要预测的可选数据集。如果为 NULL,则 使用原始训练数据,但如果训练模型使用配方, 会发生错误。

我很困惑为什么会遇到这个错误。它与我如何定义新数据有关吗?我的数据具有正确数量的变量/列(与训练数据集相同),所以我不知道是什么导致了错误。

【问题讨论】:

  • 您的数据框很可能有问题。检查dim(as.data.frame(X[curtrainfoldi, ]))dim(X[curtrainfoldi, ])
  • @StupidWolf,感谢您的回复。这两个命令都返回确切的列数,据我了解应该是这种情况。关于可能出现什么问题的任何其他想法?
  • 是的,它应该......所以我无法重现你的错误......
  • 将 X[curtrainfoldi, ] 转换为 data.frame 似乎也解决了这个问题。我不确定为什么会这样,因为它仍然具有相同数量的列,但现在似乎一切正常。
  • 这是列名.. 我猜你要么在矩阵中没有列,要么你的列名被 as.data.frame 更改了

标签: r r-caret predict glmnet


【解决方案1】:

您收到错误是因为当您传递as.data.frame(X) 时您的列名发生了变化。如果您的矩阵没有列名,它会创建列名,并且模型在尝试预测时会期望这些。如果它有列名,那么其中一些可以更改:

library(caret)
library(tibble)

X =  matrix(runif(50*20),ncol=20)
y = rnorm(50)

cv_glmnet <- caret::train(x = as.data.frame(X), y = y,
                       method = "glmnet",
                       preProcess = NULL,
                       trControl = trainControl(method = "cv", number = 10),
                       tuneLength = 10)

yhat <- predict.train(cv_glmnet, newdata = X) 

Warning message:
In nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,  :
  There were missing values in resampled performance measures.
Error in predict.glmnet(modelFit, newdata, s = modelFit$lambdaOpt) : 
  The number of variables in newx must be 20 

如果你有列名,它可以工作

colnames(X) = paste0("column",1:ncol(X))
cv_glmnet <- caret::train(x = as.data.frame(X), y = y,
                       method = "glmnet",
                       preProcess = NULL,
                       trControl = trainControl(method = "cv", number = 10),
                       tuneLength = 10)

yhat <- predict.train(cv_glmnet, newdata = X)

【讨论】:

    猜你喜欢
    • 2018-05-17
    • 2016-09-29
    • 2013-08-04
    • 2015-10-31
    • 1970-01-01
    • 2021-03-28
    • 2015-02-12
    • 2019-09-09
    • 2014-09-27
    相关资源
    最近更新 更多