【问题标题】:Train test split in `r`'s `caret` package`r` 的 `caret` 包中的训练测试拆分
【发布时间】:2016-06-13 14:26:33
【问题描述】:

我对 rcaret 包很熟悉,但是,它来自其他编程语言,让我非常困惑。

我现在要做的是一个相当简单的机器学习工作流程,即:

  1. 获取一个训练集,在我的例子中是 iris 数据集
  2. 将其拆分为训练和测试集(80-20 拆分)
  3. 对于从120 的每个k,在训练集上训练k 最近邻分类器
  4. 在测试集上测试它

我了解如何执行第一部分,因为 iris 已经加载。然后,第二部分通过调用来完成

a <- createDataPartition(iris$Species, list=FALSE)
training <- iris[a,]
test <- iris[-a,]

现在,我也知道我可以通过调用来训练模型

library(caret)
knnFit <- train()
knnFit <- train(Species~., data=training, method="knn")

但是,这将导致r 已经对参数k 执行了一些优化。当然,我可以限制 k 方法应该尝试的值,例如

knnFit <- train(Species~., data=training, method="knn", tuneGrid=data.frame(k=1:20))

它工作得很好,但它仍然不是我想要它做的。这段代码现在可以为每个k

  1. test 获取引导样本。
  2. 使用给定样本评估k-nn 方法的性能

我想要它做什么:

  1. 对于每个 k在我之前构建的同一训练集上训练模型
  2. 在我之前构建的同一测试集上评估性能。

所以我需要类似的东西

knnFit <- train(Species~., training_data=training, test_data=test, method="knn", tuneGrid=data.frame(k=1:20))

但这当然行不通。

我知道我应该对 trainControl 参数做一些事情,但我看到它可能的方法是:

"boot", "boot632", "cv", "repeatedcv", "LOOCV", "LGOCV", "none"

这些似乎都不能满足我的要求。

【问题讨论】:

    标签: r r-caret


    【解决方案1】:

    请通读caret website 了解一切如何运作。或者阅读 Max Kuhn 所写的“Applied Predictive Modeling”一书,了解有关插入符号如何工作的更多信息。

    粗略地说,trainControl 包含用于训练函数的各种参数集,例如交叉验证设置、要应用的指标(ROC / RMSE)、采样、预处理等。

    在火车中,您可以设置其他设置,例如网格搜索。我扩展了您的代码示例,因此它可以工作。确保检查 createDataPartition 的工作方式,因为默认设置将数据分成两半。

    library(caret)
    
    a <- createDataPartition(iris$Species, p = 0.8, list=FALSE)
    training <- iris[a,]
    test <- iris[-a,]
    
    knnFit <- train(Species ~ ., 
                    data = training, 
                    method="knn",  
                    tuneGrid=data.frame(k=1:20))
    
    knn_pred <- predict(knnFit, newdata = test)
    

    根据评论编辑:

    你想要的东西是不可能用一个火车对象来实现的。 Train 将使用 tunegrid 找到最佳 k 并在 finalModel 中使用该结果。这个 finalModel 将用于进行预测。

    如果您想了解所有 k 的概览,您可能不想使用 caret 的 train 函数,而是自己编写一个函数。也许像下面这样。请注意,knn3 是来自插入符号的 knn 模型。

    k <- 20
    knn_fit_list <- list()
    knn_pred_list <- list()
    
    for (i in 1:k) {
      knn_fit_list[[i]] <- knn3(Species ~ ., 
                                data = training, 
                                k = i)
      knn_pred_list[[i]] <- predict(knn_fit_list[[i]], newdata = test, type = "class")
    
    }
    

    knn_fit_list 将包含指定数量 k 的所有拟合模型。 knn_pred_list 将包含所有预测。

    【讨论】:

    • 如果我的问题有点不清楚,我很抱歉,但我认为这不能回答它。我了解 tuneGrid 参数,这就是我在示例代码中使用它的原因。我想要的是有一个与您的代码类似的代码,但它会为训练集中的k 的每个 值返回一个预测。我希望我的解释更清楚。
    【解决方案2】:

    如果我正确理解了这个问题,这可以在插入符号中使用 LGOCV(Leave-group-out-CV = 重复训练/测试拆分)并设置训练百分比 p = 0.8 和训练/测试的重复次数来完成如果您真的只想要一个模型适合每个 k 在测试集上测试,则拆分为 number = 1。设置 number > 1 将重复评估模型在 number 不同的训练/测试分割上的性能。

    data(iris)
    library(caret)
    set.seed(123)
    mod <- train(Species ~ ., data = iris, method = "knn", 
                 tuneGrid = expand.grid(k=1:20),
                 trControl = trainControl(method = "LGOCV", p = 0.8, number = 1,
                                          savePredictions = T))
    

    如果savePredictions = T,测试集上不同模型所做的所有预测都在mod$pred中。注意rowIndex:这些是已经采样到测试集中的行。这些对于k 的所有不同值都是相等的,因此每次都使用相同的训练/测试集。

    > head(mod$pred)
        pred    obs rowIndex k  Resample
    1 setosa setosa        5 1 Resample1
    2 setosa setosa        6 1 Resample1
    3 setosa setosa       10 1 Resample1
    4 setosa setosa       12 1 Resample1
    5 setosa setosa       16 1 Resample1
    6 setosa setosa       17 1 Resample1
    > tail(mod$pred)
             pred       obs rowIndex  k  Resample
    595 virginica virginica      130 20 Resample1
    596 virginica virginica      131 20 Resample1
    597 virginica virginica      135 20 Resample1
    598 virginica virginica      137 20 Resample1
    599 virginica virginica      145 20 Resample1
    600 virginica virginica      148 20 Resample1 
    

    除非需要某种嵌套验证过程,否则无需在插入符号之外手动构建训练/测试集。您还可以通过plot(mod) 绘制k 的不同值的验证曲线。

    【讨论】:

    • 看起来这就是我想要的,是的。谢谢!
    • 我就是喜欢caret 让代码如此简洁的方式。当然,要为这种简洁的语法付出代价,必须在更高层次上理解工作流,这对初学者来说可能会很吓人。但是,当您想快速完成很多工作时,它会带来丰厚的回报。
    • 我发现编写自己的“插入符号”类型函数更容易。更有趣,我可以更好地控制交叉验证和参数优化的工作方式。我注意到人们会盲目地使用 15 种不同的 ML 类型启动插入符号运行,并使用默认参数希望将尾巴钉在驴身上。相反,他们最终会得到 15 个糟糕的模型,并且不明白为什么他们会看到糟糕的表现。我的建议?学习一些建模技术(glmnet、xgboost、RF)到你编写自己的简历脚本的地步,你将有更大的成功机会
    猜你喜欢
    • 2020-11-25
    • 1970-01-01
    • 2018-04-22
    • 2021-06-28
    • 2019-07-21
    • 2021-01-01
    • 2018-12-21
    • 1970-01-01
    • 2019-03-16
    相关资源
    最近更新 更多