【问题标题】:Getting precision, recall and F1 score per class in Keras在 Keras 中获取每个班级的准确率、召回率和 F1 分数
【发布时间】:2018-10-08 11:32:56
【问题描述】:

我已经使用 Keras (2.1.5) 中的 TensorFlow 后端训练了一个神经网络,并且我还使用了 keras-contrib (2.0.8) 库来添加一个 CRF 层作为网络的输出。

我想知道在使用 NN 对测试集进行预测后,如何获得每个类的准确率、召回率和 f1 分数。

【问题讨论】:

标签: python tensorflow neural-network deep-learning keras


【解决方案1】:

假设您有一个函数 get_model() 可以构建您已训练的完全相同的模型,并且路径 weights_path 指向包含模型权重的 HDF5 文件:

model = get_model()
model.load_weights(weights_path)

这应该会正确加载您的模型。然后,您只需定义测试数据的ImageDataGenerator 并拟合模型以获得预测:

# Path to your folder testing data
testing_folder = ""
# Image size (set up the image size used for training)
img_size = 256
# Batch size (you should tune it based on your memory)
batch_size = 16

val_datagen = ImageDataGenerator(
    rescale=1. / 255)
validation_generator = val_datagen.flow_from_directory(
    testing_folder,
    target_size=(img_size, img_size),
    batch_size=batch_size,
    shuffle=False,
    class_mode='categorical')

然后,您可以使用model.predict_generator() 方法使模型在整个数据集上生成所有预测:

# Number of steps corresponding to an epoch
steps = 100
predictions = model.predict_generator(validation_generator, steps=steps)

最后使用sklearn 包中的metrics.confusion_matrix() 方法创建一个混淆矩阵:

val_preds = np.argmax(predictions, axis=-1)
val_trues = validation_generator.classes
cm = metrics.confusion_matrix(val_trues, val_preds)

或者使用sklearn 中的metrics.precision_recall_fscore_support() 方法获取所有类的所有精度、召回率和f1 分数(参数average=None 输出所有类的指标):

# label names
labels = validation_generator.class_indices.keys()
precisions, recall, f1_score, _ = metrics.precision_recall_fscore_support(val_trues, val_preds, labels=labels)

我没有测试过,但我想这会对你有所帮助。

【讨论】:

    【解决方案2】:

    看看sklearn.metrics.classification_report:

    from sklearn.metrics import classification_report
    
    y_pred = model.predict(x_test)
    print(classification_report(y_true, y_pred))
    

    给你类似的东西

                 precision    recall  f1-score   support
    
        class 0       0.50      1.00      0.67         1
        class 1       0.00      0.00      0.00         1
        class 2       1.00      0.67      0.80         3
    
    avg / total       0.70      0.60      0.61         5
    

    【讨论】:

      猜你喜欢
      • 2020-06-29
      • 2021-06-24
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-12-28
      • 2019-09-06
      • 2016-04-14
      • 2018-06-02
      相关资源
      最近更新 更多