[Question title]: Skip fitting the final model with caret
[Posted]: 2019-03-28 17:52:37
[Question]:

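Sometimes when I fit a model with caret, all I really want is to see how it performs under my chosen resampling method (e.g. cross-validation).

When I am not interested in the "final model" fit on the full training data, I would like to avoid fitting it. The point is simply to save time, which adds up over many runs during development.

Is there any way to skip fitting the final model when using caret? I don't see any relevant argument in caret::trainControl or caret::train.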

[Question comments]:

    Tags: r r-caret


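    [Answer 1]:

    There indeed seems to be no argument that achieves this directly. There are, however, a few candidate workarounds.

    1. selectionFunction, an argument of trainControl, chooses the final model among the candidates according to their performance in terms of accuracy, RMSE, etc. (without parameter tuning there is only one candidate). A selectionFunction such as function(x, ...) NA or function(x, ...) NULL fails. However, something like function(x, ...) -1 does partially work: no warning or error is returned, and caret still attempts to fit the final model. The end result seems to depend on the model.

    2. Another interesting trainControl argument is indexFinal:

      an optional vector of integers indicating which samples are used to fit the final model after resampling. If NULL, then the entire data set is used.

      Setting it to NA seems to fail for most models, except kNN. Setting it to something like 1:10 fits the final model on only those ten observations, provided the model has few enough parameters. So setting it to something like 1:100 should work in many cases and take very little time.

    3. You can, of course, modify the train function itself. Below I only add an argument fitFinal, TRUE by default, and check whether it is TRUE at the point where the final model is fit. If fitFinal == FALSE, then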

      finalModel <- list(fit = NULL, preProc = NULL)
      finalTime <- 0
      

      Everything else seems to go through fine. To override the actual train.default function, run the following after defining myTrain:

      environment(myTrain) <- environment(caret:::train.default)
      assignInNamespace("train.default", myTrain, ns = "caret")
      

      So, we have

      myTrain <- function (x, y, method = "rf", preProcess = NULL, ..., weights = NULL, fitFinal = TRUE,
                           metric = ifelse(is.factor(y), "Accuracy", "RMSE"), maximize = ifelse(metric %in%
                                                                                                  c("RMSE", "logLoss", "MAE"), FALSE, TRUE), trControl = trainControl(),
                           tuneGrid = NULL, tuneLength = ifelse(trControl$method ==
                                                                  "none", 1, 3))
      {
        startTime <- proc.time()
        rs_seed <- sample.int(.Machine$integer.max, 1L)
        if (is.null(colnames(x)))
          stop("Please use column names for `x`", call. = FALSE)
        if (is.character(y))
          y <- as.factor(y)
        if (!is.numeric(y) & !is.factor(y)) {
          stop("Please make sure `y` is a factor or numeric value.",
               call. = FALSE)
        }
        if (is.list(method)) {
          minNames <- c("library", "type", "parameters", "grid",
                        "fit", "predict", "prob")
          nameCheck <- minNames %in% names(method)
          if (!all(nameCheck))
            stop(paste("some required components are missing:",
                       paste(minNames[!nameCheck], collapse = ", ")),
                 call. = FALSE)
          models <- method
          method <- "custom"
        }
        else {
          models <- getModelInfo(method, regex = FALSE)[[1]]
          if (length(models) == 0)
            stop(paste("Model", method, "is not in caret's built-in library"),
                 call. = FALSE)
        }
        checkInstall(models$library)
        for (i in seq(along = models$library)) do.call("requireNamespaceQuietStop",
                                                       list(package = models$library[i]))
        if (any(names(models) == "check") && is.function(models$check)) {
          software_check <- models$check(models$library)
        }
        paramNames <- as.character(models$parameters$parameter)
        funcCall <- match.call(expand.dots = TRUE)
        modelType <- get_model_type(y)
        if (!(modelType %in% models$type))
          stop(paste("wrong model type for", tolower(modelType)),
               call. = FALSE)
        if (grepl("^svm", method) & grepl("String$", method)) {
          if (is.vector(x) && is.character(x)) {
            stop("'x' should be a character matrix with a single column for string kernel methods",
                 call. = FALSE)
          }
          if (is.matrix(x) && is.numeric(x)) {
            stop("'x' should be a character matrix with a single column for string kernel methods",
                 call. = FALSE)
          }
          if (is.data.frame(x)) {
            stop("'x' should be a character matrix with a single column for string kernel methods",
                 call. = FALSE)
          }
        }
        if (modelType == "Regression" & length(unique(y)) == 2)
          warning(paste("You are trying to do regression and your outcome only has",
                        "two possible values Are you trying to do classification?",
                        "If so, use a 2 level factor as your outcome column."))
        if (modelType != "Classification" & !is.null(trControl$sampling))
          stop("sampling methods are only implemented for classification problems",
               call. = FALSE)
        if (!is.null(trControl$sampling)) {
          trControl$sampling <- parse_sampling(trControl$sampling)
        }
        if (any(class(x) == "data.table"))
          x <- as.data.frame(x)
        check_dims(x = x, y = y)
        n <- if (class(y)[1] == "Surv")
          nrow(y)
        else length(y)
        parallel_check("RWeka", models)
        parallel_check("keras", models)
        if (!is.null(preProcess) && !(all(names(preProcess) %in%
                                          ppMethods)))
          stop(paste("pre-processing methods are limited to:",
                     paste(ppMethods, collapse = ", ")), call. = FALSE)
        if (modelType == "Classification") {
          classLevels <- levels(y)
          attributes(classLevels) <- list(ordered = is.ordered(y))
          xtab <- table(y)
          if (any(xtab == 0)) {
            xtab_msg <- paste("'", names(xtab)[xtab == 0], "'",
                              collapse = ", ", sep = "")
            stop(paste("One or more factor levels in the outcome has no data:",
                       xtab_msg), call. = FALSE)
          }
          if (trControl$classProbs && any(classLevels != make.names(classLevels))) {
            stop(paste("At least one of the class levels is not a valid R variable name;",
                       "This will cause errors when class probabilities are generated because",
                       "the variables names will be converted to ",
                       paste(make.names(classLevels), collapse = ", "),
                       ". Please use factor levels that can be used as valid R variable names",
                       " (see ?make.names for help)."), call. = FALSE)
          }
          if (metric %in% c("RMSE", "Rsquared"))
            stop(paste("Metric", metric, "not applicable for classification models"),
                 call. = FALSE)
          if (!trControl$classProbs && metric == "ROC")
            stop(paste("Class probabilities are needed to score models using the",
                       "area under the ROC curve. Set `classProbs = TRUE`",
                       "in the trainControl() function."), call. = FALSE)
          if (trControl$classProbs) {
            if (!is.function(models$prob)) {
              warning("Class probabilities were requested for a model that does not implement them")
              trControl$classProbs <- FALSE
            }
          }
        }
        else {
          if (metric %in% c("Accuracy", "Kappa"))
            stop(paste("Metric", metric, "not applicable for regression models"),
                 call. = FALSE)
          classLevels <- NA
          if (trControl$classProbs) {
            warning("cannnot compute class probabilities for regression")
            trControl$classProbs <- FALSE
          }
        }
        if (trControl$method == "oob" & is.null(models$oob))
          stop("Out of bag estimates are not implemented for this model",
               call. = FALSE)
        trControl <- withr::with_seed(rs_seed, make_resamples(trControl,
                                                              outcome = y))
        if (is.logical(trControl$savePredictions)) {
          trControl$savePredictions <- if (trControl$savePredictions)
            "all"
          else "none"
        }
        else {
          if (!(trControl$savePredictions %in% c("all", "final",
                                                 "none")))
            stop("`savePredictions` should be either logical or \"all\", \"final\" or \"none\"",
                 call. = FALSE)
        }
        if (!is.null(preProcess)) {
          ppOpt <- list(options = preProcess)
          if (length(trControl$preProcOptions) > 0)
            ppOpt <- c(ppOpt, trControl$preProcOptions)
        }
        else ppOpt <- NULL
        if (is.null(tuneGrid)) {
          if (!is.null(ppOpt) && length(models$parameters$parameter) >
              1 && as.character(models$parameters$parameter) !=
              "parameter") {
            pp <- list(method = ppOpt$options)
            if ("ica" %in% pp$method)
              pp$n.comp <- ppOpt$ICAcomp
            if ("pca" %in% pp$method)
              pp$thresh <- ppOpt$thresh
            if ("knnImpute" %in% pp$method)
              pp$k <- ppOpt$k
            pp$x <- x
            ppObj <- do.call("preProcess", pp)
            tuneGrid <- models$grid(x = predict(ppObj, x), y = y,
                                    len = tuneLength, search = trControl$search)
            rm(ppObj, pp)
          }
          else {
            tuneGrid <- models$grid(x = x, y = y, len = tuneLength,
                                    search = trControl$search)
            if (trControl$search != "grid" && tuneLength < nrow(tuneGrid))
              tuneGrid <- tuneGrid[1:tuneLength, , drop = FALSE]
          }
        }
        if (grepl("adaptive", trControl$method) & nrow(tuneGrid) ==
            1) {
          stop(paste("For adaptive resampling, there needs to be more than one",
                     "tuning parameter for evaluation"), call. = FALSE)
        }
        dotNames <- hasDots(tuneGrid, models)
        if (dotNames)
          colnames(tuneGrid) <- gsub("^\\.", "", colnames(tuneGrid))
        tuneNames <- as.character(models$parameters$parameter)
        goodNames <- all.equal(sort(tuneNames), sort(names(tuneGrid)))
        if (!is.logical(goodNames) || !goodNames) {
          stop(paste("The tuning parameter grid should have columns",
                     paste(tuneNames, collapse = ", ", sep = "")), call. = FALSE)
        }
        if (trControl$method == "none" && nrow(tuneGrid) != 1)
          stop("Only one model should be specified in tuneGrid with no resampling",
               call. = FALSE)
        trControl$yLimits <- if (is.numeric(y))
          get_range(y)
        else NULL
        if (trControl$method != "none") {
          if (is.function(models$loop) && nrow(tuneGrid) > 1) {
            trainInfo <- models$loop(tuneGrid)
            if (!all(c("loop", "submodels") %in% names(trainInfo)))
              stop("The 'loop' function should produce a list with elements 'loop' and 'submodels'",
                   call. = FALSE)
            lengths <- unlist(lapply(trainInfo$submodels, nrow))
            if (all(lengths == 0))
              trainInfo$submodels <- NULL
          }
          else trainInfo <- list(loop = tuneGrid)
          num_rs <- if (trControl$method != "oob")
            length(trControl$index)
          else 1L
          if (trControl$method %in% c("boot632", "optimism_boot",
                                      "boot_all"))
            num_rs <- num_rs + 1L
          if (is.null(trControl$seeds) || all(is.na(trControl$seeds))) {
            seeds <- sample.int(n = 1000000L, size = num_rs *
                                  nrow(trainInfo$loop) + 1L)
            seeds <- lapply(seq(from = 1L, to = length(seeds),
                                by = nrow(trainInfo$loop)), function(x) {
                                  seeds[x:(x + nrow(trainInfo$loop) - 1L)]
                                })
            seeds[[num_rs + 1L]] <- seeds[[num_rs + 1L]][1L]
            trControl$seeds <- seeds
          }
          else {
            if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds))) {
              numSeeds <- unlist(lapply(trControl$seeds, length))
              badSeed <- (length(trControl$seeds) < num_rs +
                            1L) || (any(numSeeds[-length(numSeeds)] < nrow(trainInfo$loop))) ||
                (numSeeds[length(numSeeds)] < 1L)
              if (badSeed)
                stop(paste("Bad seeds: the seed object should be a list of length",
                           num_rs + 1, "with", num_rs, "integer vectors of size",
                           nrow(trainInfo$loop), "and the last list element having at least a",
                           "single integer"), call. = FALSE)
              if (any(is.na(unlist(trControl$seeds))))
                stop("At least one seed is missing (NA)", call. = FALSE)
            }
          }
          if (trControl$method == "oob") {
            perfNames <- metric
          }
          else {
            testSummary <- evalSummaryFunction(y, wts = weights,
                                               ctrl = trControl, lev = classLevels, metric = metric,
                                               method = method)
            perfNames <- names(testSummary)
          }
          if (!(metric %in% perfNames)) {
            oldMetric <- metric
            metric <- perfNames[1]
            warning(paste("The metric \"", oldMetric, "\" was not in ",
                          "the result set. ", metric, " will be used instead.",
                          sep = ""))
          }
          if (trControl$method == "oob") {
            tmp <- oobTrainWorkflow(x = x, y = y, wts = weights,
                                    info = trainInfo, method = models, ppOpts = preProcess,
                                    ctrl = trControl, lev = classLevels, ...)
            performance <- tmp
            perfNames <- colnames(performance)
            perfNames <- perfNames[!(perfNames %in% as.character(models$parameters$parameter))]
            if (!(metric %in% perfNames)) {
              oldMetric <- metric
              metric <- perfNames[1]
              warning(paste("The metric \"", oldMetric, "\" was not in ",
                            "the result set. ", metric, " will be used instead.",
                            sep = ""))
            }
          }
          else {
            if (trControl$method == "LOOCV") {
              tmp <- looTrainWorkflow(x = x, y = y, wts = weights,
                                      info = trainInfo, method = models, ppOpts = preProcess,
                                      ctrl = trControl, lev = classLevels, ...)
              performance <- tmp$performance
            }
            else {
              if (!grepl("adapt", trControl$method)) {
                tmp <- nominalTrainWorkflow(x = x, y = y, wts = weights,
                                            info = trainInfo, method = models, ppOpts = preProcess,
                                            ctrl = trControl, lev = classLevels, ...)
                performance <- tmp$performance
                resampleResults <- tmp$resample
              }
              else {
                tmp <- adaptiveWorkflow(x = x, y = y, wts = weights,
                                        info = trainInfo, method = models, ppOpts = preProcess,
                                        ctrl = trControl, lev = classLevels, metric = metric,
                                        maximize = maximize, ...)
                performance <- tmp$performance
                resampleResults <- tmp$resample
              }
            }
          }
          trControl$indexExtra <- NULL
          if (!(trControl$method %in% c("LOOCV", "oob"))) {
            if (modelType == "Classification" && length(grep("^\\cell",
                                                             colnames(resampleResults))) > 0) {
              resampledCM <- resampleResults[, !(names(resampleResults) %in%
                                                   perfNames)]
              resampleResults <- resampleResults[, -grep("^\\cell",
                                                         colnames(resampleResults))]
            }
            else resampledCM <- NULL
          }
          else resampledCM <- NULL
          if (trControl$verboseIter) {
            cat("Aggregating results\n")
            flush.console()
          }
          perfCols <- names(performance)
          perfCols <- perfCols[!(perfCols %in% paramNames)]
          if (all(is.na(performance[, metric]))) {
            cat(paste("Something is wrong; all the", metric,
                      "metric values are missing:\n"))
            print(summary(performance[, perfCols[!grepl("SD$",
                                                        perfCols)], drop = FALSE]))
            stop("Stopping", call. = FALSE)
          }
          if (!is.null(models$sort))
            performance <- models$sort(performance)
          if (any(is.na(performance[, metric])))
            warning("missing values found in aggregated results")
          if (trControl$verboseIter && nrow(performance) > 1) {
            cat("Selecting tuning parameters\n")
            flush.console()
          }
          selectClass <- class(trControl$selectionFunction)[1]
          if (grepl("adapt", trControl$method)) {
            perf_check <- subset(performance, .B == max(performance$.B))
          }
          else perf_check <- performance
          if (selectClass == "function") {
            bestIter <- trControl$selectionFunction(x = perf_check,
                                                    metric = metric, maximize = maximize)
          }
          else {
            if (trControl$selectionFunction == "oneSE") {
              bestIter <- oneSE(perf_check, metric, length(trControl$index),
                                maximize)
            }
            else {
              bestIter <- do.call(trControl$selectionFunction,
                                  list(x = perf_check, metric = metric, maximize = maximize))
            }
          }
          if (is.na(bestIter) || length(bestIter) != 1)
            stop("final tuning parameters could not be determined",
                 call. = FALSE)
          if (grepl("adapt", trControl$method)) {
            best_perf <- perf_check[bestIter, as.character(models$parameters$parameter),
                                    drop = FALSE]
            performance$order <- 1:nrow(performance)
            bestIter <- merge(performance, best_perf)$order
            performance$order <- NULL
          }
          bestTune <- performance[bestIter, paramNames, drop = FALSE]
        }
        else {
          bestTune <- tuneGrid
          performance <- evalSummaryFunction(y, wts = weights,
                                             ctrl = trControl, lev = classLevels, metric = metric,
                                             method = method)
          perfNames <- names(performance)
          performance <- as.data.frame(t(performance))
          performance <- cbind(performance, tuneGrid)
          performance <- performance[-1, , drop = FALSE]
          tmp <- resampledCM <- NULL
        }
        if (!(trControl$method %in% c("LOOCV", "oob", "none"))) {
          byResample <- switch(trControl$returnResamp, none = NULL,
                               all = {
                                 out <- resampleResults
                                 colnames(out) <- gsub("^\\.", "", colnames(out))
                                 out
                               }, final = {
                                 out <- merge(bestTune, resampleResults)
                                 out <- out[, !(names(out) %in% names(tuneGrid)),
                                            drop = FALSE]
                                 out
                               })
        }
        else {
          byResample <- NULL
        }
        orderList <- list()
        for (i in seq(along = paramNames)) orderList[[i]] <- performance[,
                                                                         paramNames[i]]
        performance <- performance[do.call("order", orderList), ]
        if (trControl$verboseIter) {
          bestText <- paste(paste(names(bestTune), "=", format(bestTune,
                                                               digits = 3)), collapse = ", ")
          if (nrow(performance) == 1)
            bestText <- "final model"
          cat("Fitting", bestText, "on full training set\n")
          flush.console()
        }
        indexFinal <- if (is.null(trControl$indexFinal))
          seq(along = y)
        else trControl$indexFinal
        if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds)))
          set.seed(trControl$seeds[[length(trControl$seeds)]][1])
        if (fitFinal) {
          finalTime <- system.time(finalModel <- createModel(x = subset_x(x,
                                                                          indexFinal), y = y[indexFinal], wts = weights[indexFinal],
                                                             method = models, tuneValue = bestTune, obsLevels = classLevels,
                                                             pp = ppOpt, last = TRUE, classProbs = trControl$classProbs,
                                                             sampling = trControl$sampling, ...))
        } else {
          finalModel <- list(fit = NULL, preProc = NULL)
          finalTime <- 0
        }
        if (trControl$trim && !is.null(models$trim)) {
          if (trControl$verboseIter)
            old_size <- object.size(finalModel$fit)
          finalModel$fit <- models$trim(finalModel$fit)
          if (trControl$verboseIter) {
            new_size <- object.size(finalModel$fit)
            reduction <- format(old_size - new_size, units = "Mb")
            if (reduction == "0 Mb")
              reduction <- "< 0 Mb"
            p_reduction <- (unclass(old_size) - unclass(new_size))/unclass(old_size) *
              100
            p_reduction <- if (p_reduction < 1)
              "< 1%"
            else paste0(round(p_reduction, 0), "%")
            cat("Final model footprint reduced by", reduction,
                "or", p_reduction, "\n")
          }
        }
        pp <- finalModel$preProc
        finalModel <- finalModel$fit
        if (method == "pls")
          finalModel$bestIter <- bestTune
        if (method == "glmnet")
          finalModel$lambdaOpt <- bestTune$lambda
        if (trControl$returnData) {
          outData <- if (!is.data.frame(x))
            try(as.data.frame(x), silent = TRUE)
          else x
          if (inherits(outData, "try-error")) {
            warning("The training data could not be converted to a data frame for saving")
            outData <- NULL
          }
          else {
            outData$.outcome <- y
            if (!is.null(weights))
              outData$.weights <- weights
          }
        }
        else outData <- NULL
        if (trControl$savePredictions == "final")
          tmp$predictions <- merge(bestTune, tmp$predictions)
        endTime <- proc.time()
        times <- list(everything = endTime - startTime, final = finalTime)
        out <- structure(list(method = method, modelInfo = models,
                              modelType = modelType, results = performance, pred = tmp$predictions,
                              bestTune = bestTune, call = funcCall, dots = list(...),
                              metric = metric, control = trControl, finalModel = finalModel,
                              preProcess = pp, trainingData = outData, resample = byResample,
                              resampledCM = resampledCM, perfNames = perfNames, maximize = maximize,
                              yLimits = trControl$yLimits, times = times, levels = classLevels),
                         class = "train")
        trControl$yLimits <- NULL
        if (trControl$timingSamps > 0) {
          pData <- x[sample(1:nrow(x), trControl$timingSamps, replace = TRUE),
                     , drop = FALSE]
          out$times$prediction <- system.time(predict(out, pData))
        }
        else out$times$prediction <- rep(NA, 3)
        out
      }
      

      which gives

      library(caret)
      data(iris)
      TrainData <- iris[,1:4]
      TrainClasses <- iris[,5]
      
      knnFit1 <- train(TrainData, TrainClasses,
                       method = "knn",
                       preProcess = c("center", "scale"),
                       tuneLength = 10,
                       trControl = trainControl(method = "cv"), fitFinal = FALSE)
      knnFit1$finalModel
      # NULL
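
Because assignInNamespace changes caret for the rest of the R session, it may be worth keeping a copy of the original method so the patch can be undone. The following is only a sketch: patched_train is a hypothetical stand-in for myTrain (it just delegates to the original), and original_train_default, patch_active and restored are names introduced here for illustration.

```r
library(caret)

# Save the unmodified method before patching the namespace.
original_train_default <- caret:::train.default

# Hypothetical stand-in for myTrain: announces itself, then delegates.
patched_train <- function(x, y, ...) {
  message("patched train.default called")
  original_train_default(x, y, ...)
}

# Apply the patch, as in the answer above.
assignInNamespace("train.default", patched_train, ns = "caret")
patch_active <- !identical(caret:::train.default, original_train_default)

# Undo the patch once the fast development runs are finished.
assignInNamespace("train.default", original_train_default, ns = "caret")
restored <- identical(caret:::train.default, original_train_default)
```

Restoring the original this way avoids having to restart R just to get the stock train() behaviour back.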
      

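    [Comments]:

    • How about indexFinal = NA?
    • @missuse, right! I didn't expect that to work, just as with selectionFunction.
    • Setting the index to NA is not a preferred solution, as it doesn't always work; you should remove it from your answer. For instance, it fails with a simple lm: train(Sepal.Length ~ ., data=iris, method = "lm", trControl = trainControl(method = "cv", indexFinal = NA)). I'll use indexFinal=1 for now, but I would prefer to disable the final fit entirely. Thanks!
    • Actually, using NA failed for all (10+) models I have just tried, except knn... I guess you were quite unlucky!
    • @antoine-sac, interesting. I've updated the answer. Note that indexFinal = 1 doesn't seem to always work either; I just tried it with an SVM, and even 1:50 was still not enough.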
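
The indexFinal workaround discussed in this thread can be sketched as follows. This is a minimal illustration, not a general recipe: it assumes that 30 randomly chosen rows are enough for lm to fit without error (as the comments note, other models such as SVMs may need more), and final_rows is a name introduced here. The resampling estimates still use all rows; only the final fit is restricted to the subset, so the resulting finalModel should not be used for prediction.

```r
library(caret)

set.seed(1)
# Assumption: 30 random rows suffice for a cheap final lm fit.
final_rows <- sample(nrow(iris), 30)

ctrl <- trainControl(method = "cv", number = 5, indexFinal = final_rows)
fit <- train(Sepal.Length ~ ., data = iris, method = "lm", trControl = ctrl)

fit$results            # cross-validated performance, estimated from all 150 rows
nobs(fit$finalModel)   # number of rows used for the final fit
```

A random subset is used rather than 1:30 so that every Species level is likely to be present in the final fit.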