【问题标题】:Confusion matrix missing instances混淆矩阵缺失实例
【发布时间】:2017-04-06 02:33:28
【问题描述】:

我正在使用 PySpark 生成和验证预测。我有一个包含正确列的数据框,我将它传递给 MulticlassMetrics 以获取混淆矩阵。但是当我检查混淆矩阵时,它缺少大部分值(数据框有超过 120.000 行,混淆矩阵只有 8 个左右的值)。为什么会丢失其余部分?

编辑:澄清一下,我不希望混淆矩阵与数据集具有相同的大小,我的数据中有两个类,并且我希望矩阵实例的总和与数字相同我的数据中的行数。问题是我的数据中有大约 120.000 行,混淆矩阵类似于
[[ 0, 3 ], [ 1, 0 ]]

代码:我不能在这里发布整个代码,但这是重要的部分

training_data = load_training_data() # Spark DataFrame
training_data, testing_data = training_data.randomSplit([0.7, 0.3])

asm = VectorAssembler(inputCols=selected_columns, outputCol='features')
final_training_data = asm.transform(training_data)

rf = RandomForestClassifier(labelCol="label", impurity="entropy")
rfModel = rf.fit(final_training_data)

test_predictions = rfModel.transform(testing_data)
predictionAndLabels = test_predictions.select(['prediction', 'label'])

tp = predictionAndLabels.rdd.map(tuple)
metrics = MulticlassMetrics(tp)

【问题讨论】:

  • 我不确定我是否遵循。您是否希望您的混淆矩阵与您的数据集大小相同?
  • 仅供参考,混淆矩阵将是一个方阵,其维度将等于数据中的类数。因此,如果您有 3 个类,则矩阵将为 3x3
  • 我编辑了问题以澄清我的疑问
  • 那么 120k 行是您的测试集吗?否则你是如何拆分数据的?
  • 是的,我随机分配了 30% 用于测试,70% 用于训练。所以 120k 是测试集。

标签: python pyspark confusion-matrix


【解决方案1】:

这里有一个很好的例子来说明如何使用MulticlassMetrics。在此示例中,数据包含 150 个观测值,属于三个类别之一。因此,最终的混淆矩阵是 3x3 的形状,表示为一维的 DenseArray。如果您浏览链接中的示例并在到达 metrics = MulticlassMetrics(predictionAndLabels) 时停下来,您可以执行以下操作来可视化混淆矩阵。

In[6]: metrics = MulticlassMetrics(predictionAndLabels)
In[7]: confusion_mat = metrics.confusionMatrix()
In[8]: print(confusion_mat)
Out[8]: DenseMatrix(3, 3, [15.0, 0.0, 7.0, 0.0, 16.0, 0.0, 1.0, 0.0, 13.0], 0)
In[9]: print(confusion_mat.toArray())
Out[9]: array([[ 15.,   0.,   1.],
               [  0.,  16.,   0.],
               [ 7.,   0.,  13.]])

最终的数组就是你所理解的混淆矩阵。查看 Wikipedia 的 Confusion Matrix 条目以获取更多信息和多类矩阵的一个很好的示例。

如果没有关于您的数据的更多信息,我不能肯定地说,但听起来您有一个 2x2 或 3x3 混淆矩阵,您只需要调用 toArray 以更好地对其进行可视化。

编辑(感谢您添加代码。)

通常,当我运行RandomForestClassifier.transform(test) 时,我最终会得到一个predictedLabel 列,它是预测的实际类别。此外,我认为您不必致电predictionAndLabels.rdd.map(tuple)。在您选择predictedLabeltest_predictions 中的“标签”后,您应该可以直接进入指标。总结一下:

predictionAndLabels = test_predictions.select(['predictedLabel', 'label'])
metrics = MulticlassMetrics(predictionAndLabels)

【讨论】:

  • 谢谢,我去试试,如果成功了再告诉你!
  • 在 Spark 2.4 中,这会在您执行第二行时引发归因错误。 'DataFrame' object has no attribute 'ctx'
猜你喜欢
  • 1970-01-01
  • 2018-06-22
  • 2015-12-17
  • 2020-10-01
  • 2012-01-20
  • 2019-11-23
  • 1970-01-01
  • 1970-01-01
  • 2022-07-07
相关资源
最近更新 更多