【Question title】: How can I pass the ntree parameter to the train() function of the caret package?
【Posted】: 2020-02-24 16:04:06
【Question description】:

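I am using the following function to cross-validate a random forest on my dataset. However, ntree raises an error saying that it is not used in the function. I saw this usage recommended in a comment on an earlier thread about this issue, but it does not work for me. Here is my code: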

cv_rf_class1 <- train(y_train_u ~ ., x_train_u , 
                      method ="cforest", 
                      trControl = trainControl(method = "cv", 
                                               number = 10, 
                                               verboseIter = TRUE),  
                                               ntree = 100))

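If I cannot change the ntree parameter, the function uses the default of 500 trees, which raises another error (subscript out of bounds), so I cannot get it to work for my problem. How can I fix this so that my function works?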

【Question discussion】:

    Tags: r machine-learning random-forest r-caret


    【Solution 1】:

    ntree must be an argument of train, not of trainControl as you have used it here; from the train documentation:

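    ...: arguments passed to the classification or regression routine (such as randomForest). Errors will occur if values for tuning parameters are passed here.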
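A minimal sketch of that pass-through, assuming the randomForest-backed method = "rf" (not cforest) and simulated data from caret's twoClassSim(). Because ntree is not a tuning parameter of "rf" (only mtry is tuned), train() forwards it unchanged to randomForest::randomForest(); the randomForest package must be installed.

```r
library(caret)   # train(), trainControl(), twoClassSim()
# method = "rf" additionally requires the randomForest package to be installed

set.seed(1)
dat <- twoClassSim(100)   # simulated two-class data; last column `Class` is the response
x <- dat[, -ncol(dat)]    # predictors
y <- dat$Class            # factor response

# ntree goes to train() itself (its `...`), NOT to trainControl();
# train() hands it on to randomForest::randomForest()
fit_rf <- train(x, y,
                method = "rf",
                ntree = 100,
                trControl = trainControl(method = "cv", number = 3))

# the fitted randomForest object records how many trees were grown
print(fit_rf$finalModel$ntree)
```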

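    Also note that you are not passing your data in the correct form: train expects either (x, y), or a formula together with a data frame that contains the response, not the mix of a formula and a separate predictor matrix that you pass.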
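For comparison, a sketch of the two data interfaces train() accepts, using simulated twoClassSim() data and method = "rf" purely for illustration (both are assumptions, not the asker's setup). In the formula version the response has to be a column of the data frame given as data.

```r
library(caret)   # train(), trainControl(), twoClassSim()
# method = "rf" additionally requires the randomForest package to be installed

set.seed(1)
dat <- twoClassSim(100)   # simulated data; last column `Class` is the response
ctrl <- trainControl(method = "cv", number = 3)

# (x, y) interface: predictors and response are passed separately
set.seed(2)
fit_xy <- train(dat[, -ncol(dat)], dat$Class,
                method = "rf", ntree = 50, trControl = ctrl)

# formula interface: `Class` must be a column of the data frame in `data`
set.seed(2)
fit_f <- train(Class ~ ., data = dat,
               method = "rf", ntree = 50, trControl = ctrl)
```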

    All in all, change your train call to:

    cv_rf_class1 <- train(x_train_u, y_train_u,
                          method ="cforest", 
                          ntree = 100,
                          trControl = trainControl(method = "cv", 
                                                   number = 10, 
                                                   verboseIter = TRUE))
    

    UPDATE (after comments)

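    Well, cforest in particular does not seem to accept an ntree argument, because, unlike in the original randomForest package, ntree is not an argument of the underlying cforest function in the respective package (docs).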

    As shown in the relevant example in the caret GitHub repo, the correct way is:

    cv_rf_class1 <- train(x_train_u, y_train_u,
                          method ="cforest", 
                          trControl = trainControl(method = "cv", 
                                                   number = 10, 
                                                   verboseIter = TRUE),
                          controls = party::cforest_unbiased(ntree = 100))
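To see where the tree count actually lives, here is a sketch that calls party::cforest() directly on simulated twoClassSim() data (the data set and mtry = 2 are illustrative assumptions). ntree is set on the control object built by cforest_unbiased(), which is the same kind of object passed to train() through controls.

```r
library(party)   # cforest(), cforest_unbiased()
library(caret)   # twoClassSim() for simulated data

set.seed(1)
dat <- twoClassSim(100)

# cforest() has no ntree argument of its own; the tree count is part of the
# ForestControl object created by cforest_unbiased()
ctrl <- cforest_unbiased(ntree = 100, mtry = 2)
print(ctrl@ntree)   # stored in the `ntree` slot of the control object

# the control object is passed via `controls`, exactly as in the train() call
fit <- cforest(Class ~ ., data = dat, controls = ctrl)
```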
    

    Adapting cforest.R from that repo, we get:

    library(caret)
    library(plyr)
    library(recipes)
    library(dplyr)
    
    model <- "cforest"
    
    set.seed(2)
    training <- twoClassSim(50, linearVars = 2)
    testing <- twoClassSim(500, linearVars = 2)
    trainX <- training[, -ncol(training)]
    trainY <- training$Class
    
    rec_cls <- recipe(Class ~ ., data = training) %>%
      step_center(all_predictors()) %>%
      step_scale(all_predictors())
    
    seeds <- vector(mode = "list", length = nrow(training) + 1)
    seeds <- lapply(seeds, function(x) 1:20)
    
    cctrl1 <- trainControl(method = "cv", number = 3, returnResamp = "all",
                           classProbs = TRUE, 
                           summaryFunction = twoClassSummary,
                           seeds = seeds)
    
    set.seed(849)
    test_class_cv_model <- train(trainX, trainY, 
                                 method = "cforest", 
                                 trControl = cctrl1,
                                 metric = "ROC", 
                                 preProc = c("center", "scale"),
                                 controls = party::cforest_unbiased(ntree = 20)) # WORKS OK
    
    test_class_pred <- predict(test_class_cv_model, testing[, -ncol(testing)])
    test_class_prob <- predict(test_class_cv_model, testing[, -ncol(testing)], type = "prob")
    
    head(test_class_pred)
    # [1] Class2 Class2 Class2 Class1 Class1 Class1
    # Levels: Class1 Class2
    
    head(test_class_prob)
    #      Class1    Class2
    # 1 0.4996686 0.5003314
    # 2 0.4333222 0.5666778
    # 3 0.3625118 0.6374882
    # 4 0.5373396 0.4626604
    # 5 0.6174159 0.3825841
    # 6 0.5327283 0.4672717
    

    Output of sessionInfo():

    R version 3.6.1 (2019-07-05)
    Platform: x86_64-w64-mingw32/x64 (64-bit)
    Running under: Windows 7 x64 (build 7601) Service Pack 1
    
    Matrix products: default
    
    locale:
    [1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252    LC_MONETARY=English_United Kingdom.1252
    [4] LC_NUMERIC=C                            LC_TIME=English_United Kingdom.1252    
    
    attached base packages:
    [1] stats     graphics  grDevices utils     datasets  methods   base     
    
    other attached packages:
    [1] recipes_0.1.7   dplyr_0.8.3     plyr_1.8.4      caret_6.0-84    ggplot2_3.2.1   lattice_0.20-38
    
