【问题标题】:SparkR 2.0 Classification: how to get performance matrices?SparkR 2.0 分类:如何获得性能矩阵?
【发布时间】:2017-07-30 14:25:07
【问题描述】:

如何获取 sparkR 分类中的性能矩阵,例如 F1 分数、Precision、Recall、Confusion Matrix

# Load training data
df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm")
training <- df
 testing <- df

# Fit a random forest classification model with spark.randomForest
model <- spark.randomForest(training, label ~ features, "classification", numTrees = 10)

# Model summary
  summary(model)

 # Prediction
  predictions <- predict(model, testing)
  head(predictions)

 # Performance evaluation 

我试过caret::confusionMatrix(testing$label,testing$prediction)它显示错误:

   Error in unique.default(x, nmax = nmax) :   unique() applies only to vectors

【问题讨论】:

    标签: apache-spark machine-learning apache-spark-sql spark-dataframe sparkr


    【解决方案1】:

    Caret 的 confusionMatrix 不起作用,因为它需要 R 数据帧,而您的数据位于 Spark 数据帧中。

    一种推荐的获取指标的方法是使用as.data.frame 将您的Spark 数据帧本地“收集”到R,然后使用caret 等;但这意味着您的数据可以放入驱动程序机器的主内存中,在这种情况下,您当然完全没有理由使用 Spark...

    所以,这里有一种以分布式方式(即不在本地收集数据)获取准确性的方法,以iris 数据为例:

    sparkR.version()
    # "2.1.1"
    
    df <- as.DataFrame(iris)
    model <- spark.randomForest(df, Species ~ ., "classification", numTrees = 10)
    predictions <- predict(model, df)
    summary(predictions)
    # SparkDataFrame[summary:string, Sepal_Length:string, Sepal_Width:string, Petal_Length:string, Petal_Width:string, Species:string, prediction:string]
    
    createOrReplaceTempView(predictions, "predictions")
    correct <- sql("SELECT prediction, Species FROM predictions WHERE prediction=Species")
    count(correct)
    # 149
    acc = count(correct)/count(predictions)
    acc
    # 0.9933333
    

    (关于 150 个样本中的 149 个正确预测,如果您执行 showDF(predictions, numRows=150),您会看到确实有一个 virginica 样本被错误分类为 versicolor)。

    【讨论】:

      猜你喜欢
      • 2023-04-01
      • 2019-05-08
      • 1970-01-01
      • 2023-01-12
      • 2014-02-08
      • 2020-11-18
      • 2020-02-26
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多