【问题标题】:Using caret to optimize for deviance with binary classification使用插入符号优化二元分类的偏差
【发布时间】:2023-03-31 21:35:02
【问题描述】:

(从Fatal error with train() in caret on Windows 7, R 3.0.2, caret 6.0-21借来的例子)

我有这个例子:

library("AppliedPredictiveModeling")
library("caret")

data("AlzheimerDisease")
data <- data.frame(predictors, diagnosis)

tuneGrid <- expand.grid(interaction.depth = 1:2, n.trees = 100, shrinkage = 0.1)
trainControl <- trainControl(method = "cv", number = 5, verboseIter = TRUE)

gbmFit <- train(diagnosis ~ ., data = data, method = "gbm", trControl = trainControl, tuneGrid = tuneGrid)

但是,假设我想优化偏差(我认为 gbm 默认返回)而不是准确性。我知道 trainControl 提供了一个 summaryFunction 参数。如何编写一个针对偏差进行优化的 summaryFunction?

【问题讨论】:

    标签: r r-caret gbm


    【解决方案1】:

    偏差只是(减去)对数似然的两倍。对于单次试验的二项式数据,即:

    -2 \sum_{i=1}^n y_i log(\pi_i) + (1 - y_i)*log(1-\pi_i)
    

    y_i 是第一类的二进制指标,\pi 是第一类的概率。

    这是一个在 GLM 中重现偏差的简单示例(通过重新计算训练集偏差):

    > library(caret)
    > set.seed(1)
    > dat <-twoClassSim(200)
    > fit1 <- glm(Class ~ ., data = dat, family = binomial)
    > ## glm() models the last class level
    > prob_class1 <- 1 - predict(fit1, dat[, -ncol(dat)], type = "response")
    > is_class1 <- ifelse(dat$Class == "Class1", 1, 0)
    > -2*sum(is_class1*log(prob_class1) + ((1-is_class1)*log(1-prob_class1)))
    [1] 112.7706
    > fit1
    
    Call:  glm(formula = Class ~ ., family = binomial, data = dat)
    <snip>  
    Degrees of Freedom: 199 Total (i.e. Null);  184 Residual
    Null Deviance:      275.3 
    Residual Deviance: 112.8    AIC: 144.8
    

    train 的基本功能是:

    dev_summary <- function(data, lev = NULL, model = NULL) {
      is_class1 <- ifelse(data$obs == lev[1], 1, 0)
      prob_class1 <- data[, lev[1]]
    
      c(deviance = -2*sum(is_class1*log(prob_class1) + 
                            ((1-is_class1)*log(1-prob_class1))),
        twoClassSummary(data, lev = lev))
    }
    
    ctrl <- trainControl(summaryFunction = dev_summary,
                         classProbs = TRUE)
    gbm_grid <- expand.grid(interaction.depth = seq(1, 7, by = 2),
                            n.trees = seq(100, 1000, by = 50),
                            shrinkage = c(0.01, 0.1))
    set.seed(1)
    fit2 <- train(Class ~ ., data = dat,
                  method = "gbm",
                  trControl = ctrl,
                  tuneGrid = gbm_grid,
                  metric = "deviance",
                  verbose = FALSE)
    

    请注意,如果\pi 非常接近零或一,您将需要考虑做些什么。

    最大

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-03-31
      • 1970-01-01
      • 2019-10-07
      • 2013-03-06
      • 2017-01-01
      相关资源
      最近更新 更多