【问题标题】:Custom Performance Function in caret Package using predicted Probability插入符号包中使用预测概率的自定义性能函数
【发布时间】:2020-06-30 13:32:49
【问题描述】:

This SO post 是关于使用caret 包中的自定义性能测量功能。您想找到最佳预测模型,因此您构建了多个模型并通过计算从比较观察值和预测值得出的单个指标来比较它们。有默认函数来计算这个指标,但您也可以定义自己的指标函数。此自定义函数必须以 obs 和预测值作为输入。

在分类问题(假设只有两个类)中,预测值为01。但是,我需要评估的也是模型中计算的概率。有什么方法可以实现吗?

原因是在某些应用程序中,您需要知道 1 的预测实际上是 99% 的概率还是 51% 的概率 - 而不仅仅是预测是 1 还是 0。

谁能帮忙?


编辑 好的,所以让我试着解释得更好一点。在 5.5.5 (Alternate Performance Metrics) 下的 caret 包的文档中,描述了如何使用您自己的自定义性能函数,如下所示

fitControl <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 10,
                           ## Estimate class probabilities
                           classProbs = TRUE,
                           ## Evaluate performance using 
                           ## the following function
                           summaryFunction = twoClassSummary)

twoClassSummary 是本例中的自定义性能函数。这里提供的函数需要将带有obspred 的数据帧或矩阵作为输入。这就是重点 - 我想使用一个函数,它不需要观察和预测,而是观察和预测probability


还有一件事:

也欢迎来自其他软件包的解决方案。我唯一不想要的是“这就是你编写自己的交叉验证函数的方式。”

【问题讨论】:

    标签: r r-caret


    【解决方案1】:

    当您在 trainControl 中指定 classProbs = TRUE 时,Caret 确实支持将类概率传递给自定义摘要函数。在这种情况下,创建自定义汇总函数时的data 参数将有另外两列命名为包含每个类概率的类。这些类的名称将在 lev 参数中,这是一个长度为 2 的向量。

    查看示例:

    library(caret)
    library(mlbench)
    data(Sonar)
    

    自定义汇总LogLoss:

    LogLoss <- function (data, lev = NULL, model = NULL){ 
      obs <- data[, "obs"] #truth
      cls <- levels(obs) #find class names
      probs <- data[, cls[2]] #use second class name to extract probs for 2nd clas
      probs <- pmax(pmin(as.numeric(probs), 1 - 1e-15), 1e-15) #bound probability, this line and bellow is just logloss calculation, irrelevant for your question 
      logPreds <- log(probs)        
      log1Preds <- log(1 - probs)
      real <- (as.numeric(data$obs) - 1)
      out <- c(mean(real * logPreds + (1 - real) * log1Preds)) * -1
      names(out) <- c("LogLoss") #important since this is specified in call to train. Output can be a named vector of multiple values. 
      out
    }
    
    fitControl <- trainControl(method = "cv",
                               number = 5,
                               classProbs = TRUE,
                               summaryFunction = LogLoss)
    
    
    fit <-  train(Class ~.,
                 data = Sonar,
                 method = "rpart", 
                 metric = "LogLoss" ,
                 tuneLength = 5,
                 trControl = fitControl,
                 maximize = FALSE) #important, depending on calculated performance measure
    
    fit
    #output
    CART 
    
    208 samples
     60 predictor
      2 classes: 'M', 'R' 
    
    No pre-processing
    Resampling: Cross-Validated (5 fold) 
    Summary of sample sizes: 166, 166, 166, 167, 167 
    Resampling results across tuning parameters:
    
      cp          LogLoss  
      0.00000000  1.1220902
      0.01030928  1.1220902
      0.05154639  1.1017268
      0.06701031  1.0694052
      0.48453608  0.6405134
    
    LogLoss was used to select the optimal model using the smallest value.
    The final value used for the model was cp = 0.4845361.
    

    或者使用包含类级别的lev 参数并定义一些错误检查

    LogLoss <- function (data, lev = NULL, model = NULL){ 
     if (length(lev) > 2) {
            stop(paste("Your outcome has", length(lev), "levels. The LogLoss() function isn't appropriate."))
        }
      obs <- data[, "obs"] #truth
      probs <- data[, lev[2]] #use second class name
      probs <- pmax(pmin(as.numeric(probs), 1 - 1e-15), 1e-15) #bound probability
      logPreds <- log(probs)        
      log1Preds <- log(1 - probs)
      real <- (as.numeric(data$obs) - 1)
      out <- c(mean(real * logPreds + (1 - real) * log1Preds)) * -1
      names(out) <- c("LogLoss")
      out
    }
    

    查看插入符号书的这一部分:https://topepo.github.io/caret/model-training-and-tuning.html#metrics

    了解更多信息。如果您打算使用插入符号,即使您不擅长阅读,也值得一读。

    【讨论】:

    • 好像?只需在您的 R 会话中尝试一下,所有的怀疑都会消失。
    【解决方案2】:

    很遗憾,我刚刚找到了问题的答案。 caret文档里有这么一小句……

    "...如果这些参数都不令人满意,用户还可以计算自定义性能指标。trainControl 函数有一个名为 summaryFunction 的参数,它指定了一个计算性能的函数。该函数应该有这些论点:

    data 是数据框或矩阵的参考,其列称为 obs 和 pred,用于观察和预测的结果值(用于回归的数字数据或用于分类的字符值)。 目前,类概率未传递给函数。数据中的值是单个调整组合的保留预测(及其相关参考值)..."

    为了文档:这是在 2020 年 7 月 3 日写的,caret 包文档来自 2019 年 3 月 27 日。

    【讨论】:

    • 你绑定到caret了吗?较新的tidymodels 试图实现相同的目标,但正在积极开发中。 (包由 caret 的作者 Max Kuhn 推送。)
    • 哦! :) 不,我不受caret 的约束。如果您有任何其他解决方案,请提供答案。 (只是不是,“这是如何编写自己的交叉验证函数”——我可以自己做。;)
    【解决方案3】:

    我不确定我是否正确理解了您的问题:

    要从模型mdl 接收预测概率,您可以使用predict(mdl, type = "prob")。 即,

    library(caret)
    #> Loading required package: lattice
    #> Loading required package: ggplot2
    
    df <- iris
    df$isSetosa <- factor(df$Species == "setosa", levels = c(FALSE, TRUE), labels = c("not-setosa", "is-setosa"))
    df$Species <- NULL
    
    mdl <- train(isSetosa ~ ., data = df, method = "glm",
                    family = "binomial",
                    trControl = trainControl(method = "cv"))
    
    preds <- predict(mdl, newdata = df, type = "prob")
    head(preds)
    #>     not-setosa is-setosa
    #> 1 2.220446e-16         1
    #> 2 2.220446e-16         1
    #> 3 2.220446e-16         1
    #> 4 1.875722e-12         1
    #> 5 2.220446e-16         1
    #> 6 2.220446e-16         1
    
    

    reprex package (v0.3.0) 于 2020 年 7 月 2 日创建

    即,我们看到案例 4 被预测为具有 ~100% 的 setosa(tbh,这个玩具模型好得令人难以置信)...

    现在我们可以创建一个自定义函数,将值折叠为单个指标。

    true <- df$isSetosa
    
    # very basic model metrics that just sums the absolute differences in true - probability
    custom_model_metric <- function(preds, true) {
      d <- data.frame(true = true)
      tt <- predict(dummyVars(~true, d), d)
      colnames(tt) <- c("not-setosa", "is-setosa")
      
      sum(abs(tt - preds))
    }
    
    custom_model_metric(preds, true)
    #> [1] 3.294029e-09
    

    reprex package (v0.3.0) 于 2020 年 7 月 2 日创建

    【讨论】:

    • 嘿!不幸的是,这不是我想要的。我将编辑问题。已经感谢您的努力。
    猜你喜欢
    • 2014-07-27
    • 2020-01-25
    • 2016-08-04
    • 1970-01-01
    • 2013-08-26
    • 1970-01-01
    • 2019-09-03
    • 2016-10-06
    • 2015-05-14
    相关资源
    最近更新 更多