【问题标题】:average of confusion matrix in RR中混淆矩阵的平均值
【发布时间】:2018-08-23 11:02:11
【问题描述】:

我应用了 10 次交叉验证,输出是混淆矩阵的 10 倍,那么如何通过混淆矩阵找到倍数的平均值?

我的工作是否正确?

这是我的代码:

set.seed(100)
    library(caTools)
    library(caret)
    library(e1071)
    folds<-createFolds(wpdc$outcome, k=10) 
    CV <- lapply(folds, function(x){
      traing_folds=wpdc[-x,]
      test_folds=wpdc[x,]
      dataset_model_nb<-naiveBayes(outcome ~ ., data = traing_folds)
      dataset_predict_nB<-predict(dataset_model_nb, test_folds[-1])
      dataset_table_nB<-table(test_folds[,1],dataset_predict_nB)
      accuracy<-confusionMatrix(dataset_table_nB, positive ="R")
      return(accuracy)
    })

 outcome radius_mean texture_mean perimeter_mean area_mean smoothness_mean compactness_mean concavity_mean concave_points_mean symmetry_mean fractal_dimension_mean radius_se texture_se perimeter_se area_se smoothness_se
1       N       18.02        27.60         117.50    1013.0         0.09489           0.1036         0.1086             0.07055        0.1865                0.06333    0.6249     1.8900        3.972   71.55      0.004433
2       N       17.99        10.38         122.80    1001.0         0.11840           0.2776         0.3001             0.14710        0.2419                0.07871    1.0950     0.9053        8.589  153.40      0.006399
3       N       21.37        17.44         137.50    1373.0         0.08836           0.1189         0.1255             0.08180        0.2333                0.06010    0.5854     0.6105        3.928   82.15      0.006167

【问题讨论】:

  • 您是否要计算平均 cv 准确率? “混淆矩阵的平均折叠”是什么意思?请提供样本数据集。
  • 是的,输出是 10 个混淆矩阵,所以我需要平均值(一个混淆矩阵)我的意思是我需要以混淆矩阵的方式计算平均值,希望你能理解我,谢谢跨度>
  • 1.让你的函数只返回matrix,即return(accuracy$table)。 2. 使用ReduceCV 中的矩阵求和,即。 Reduce('+', CV)
  • 非常有帮助,谢谢

标签: r data-mining naivebayes


【解决方案1】:

我需要同样的,然后按照@Stephen Handerson 的提示,我就是:

  1. 定义矩阵列表:
    • rfConfusionMatrices &lt;- list()
  2. 将每个矩阵存储在该列表中:
    • RrfConfusionMatrix[[i]] &lt;- confMatrix
  3. 使用Reduce 函数对矩阵求和并除以折叠:
    • rfConfusionMatrixMean &lt;- Reduce('+', rfConfusionMatrix) / nFolds

【讨论】:

    【解决方案2】:

    如果您重新组织代码并将预测和真实标签存储为:

    set.seed(100)
        library(caTools)
        library(caret)
        library(e1071)
        folds <- createFolds(wpdc$outcome, k=10) 
        CV <- lapply(folds, function(x){
          traing_folds=wpdc[-x,]
          test_folds=wpdc[x,]
          dataset_model_nb<-naiveBayes(outcome ~ ., data = traing_folds)
          dataset_predict_nB<-predict(dataset_model_nb, test_folds[-1])
          dataset_table_nB<-table(test_folds[,1],dataset_predict_nB)
          return(dataset_table_nB) # storing true and predicted values
        })
    

    您可以通过减少来附加它们:

    appended_table_nB<- do.call(rbind, dataset_table_nB)
    

    然后取混淆矩阵:

    accuracy <- confusionMatrix(appended_table_nB, positive ="R")
    

    这与取平均值相同。唯一的区别是您将 conf 矩阵中的数据点相加,但准确度和其他指标是它们的平均值。如果您想查看 conf 矩阵的平均值,您可以:

    averaged_matrix &lt;- as.matrix(accuracy) / nFold

    【讨论】:

      【解决方案3】:

      我刚刚在 Google 上搜索,以了解从混淆矩阵中计算平均值是否很常见。以防万一有人对可以调整以保存的不仅仅是平均值的解决方案感兴趣:

      我定义了以下函数来从混淆矩阵或类似对象的list 中获取均值和标准差,因为所有这些矩阵都具有相同的格式:

      average_matr <- function(matr_list){
        if(class(matr_list[[1]])[1] == "confusionMatrix"){
          matr_lst <- lapply(matr_list, FUN = function(x){x$table})
        }else{
          matr_lst <- matr_list
        }
        vals <- lapply(matr_lst, as.numeric)
        matr <- do.call(cbind, vals)
        #vec_mean <- apply(matr, MARGIN = 1, FUN = mean, na.rm = TRUE)
        vec_mean <- rowMeans(matr, na.rm = TRUE)
        matr_mean <- matrix(vec_mean, nrow = nrow(matr_lst[[1]]))
        vec_sd <- apply(matr, MARGIN = 1, FUN = sd, na.rm = TRUE)
        matr_sd <- matrix(vec_sd, nrow = nrow(matr_lst[[1]]))
        out <- list(matr_mean, matr_sd)
        return(out)
      }
      
      average_matr(confusion_matr)
      

      如果列表中的对象属于confusionMatrix 类,则该函数将仅提取值。如果是矩阵列表,会计算均值和标准差。

      请注意,rowMeans 应该比带有FUN = meanapply 快,但是,据我所知,没有sd 功能。虽然我使用了类似的语法,但 applymean 可以被替换,但对于较小的数据集应该没有明显的区别。

      编辑:添加了两个版本。

      附加:包括导出为 LaTeX 表格

      average_matr <- function(matr_list, latex_file = NA,
                               metric = "sd", return = TRUE){
        if(class(matr_list[[1]])[1] == "confusionMatrix"){
          matr_lst <- lapply(matr_list, FUN = function(x){x$table})
        }else{
          matr_lst <- matr_list
        }
        vals <- lapply(matr_lst, as.numeric)
        matr <- do.call(cbind, vals)
        #vec_mean <- apply(matr, MARGIN = 1, FUN = mean, na.rm = TRUE)
        vec_mean <- rowMeans(matr, na.rm = TRUE)
        matr_mean <- matrix(vec_mean, nrow = nrow(matr_lst[[1]]))
        if(metric == "sd"){
          vec_sd <- apply(matr, MARGIN = 1, FUN = sd, na.rm = TRUE)
        }else if(metric == "se"){
          vec_sd <- apply(matr, MARGIN = 1,
                          FUN = function(x){sd(x, na.rm = TRUE)/sqrt(length(x))})
        }else{
          vec_sd <- NA
        }
        if(length(vec_sd) > 1){
          matr_sd <- matrix(vec_sd, nrow = nrow(matr_lst[[1]]))
          out <- list(matr_mean, matr_sd)
        }else{
          out <- matr_mean
        }
        # generate latex table
        if(is.character(latex_file)){
          if(dir.exists(dirname(latex_file))){
            sink(latex_file)
            cat("\\hline\n")
            cat(paste(row.names(matr_lst[[1]]), collapse = " & "), "\\\\\n")
            cat("\\hline\n")
            if(length(vec_sd) > 1){
              for(r in 1:nrow(matr_mean)){
                cat(paste(formatC(matr_mean[r, ], digits = 1, format = "f"),
                          formatC(matr_sd[r, ], digits = 1, format = "f"),
                          sep = " \\(\\pm\\) ", collapse = " & "), "\\\\\n")
              }
            }else{
              for(r in 1:nrow(matr_mean)){
                cat(paste(formatC(matr_mean, digits = 1, format = "f"),
                           collapse = " & "), "\\\\\n")
              }
            }
            cat("\\hline\n")
            sink()
          }else{
            warning("Directory not found: ", latex_file)
          }
        }
        if(return){
          return(out)
        }
      }
      

      【讨论】:

        猜你喜欢
        • 2016-08-13
        • 1970-01-01
        • 1970-01-01
        • 2018-01-25
        • 2022-01-05
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多