【问题标题】:How do you create a boolean mask for a tensor in Keras?如何在 Keras 中为张量创建布尔掩码?
【发布时间】:2017-05-18 10:49:49
【问题描述】:

我正在构建一个自定义指标来衡量训练期间我的多类数据集中一个类的准确性。我在选择课程时遇到问题。

目标是一个热点(例如:类0 标签是[1 0 0 0 0]):

from keras import backend as K

def single_class_accuracy(y_true, y_pred):
    idx = bool(y_true[:, 0])              # boolean mask for class 0 
    class_preds = y_pred[idx]
    class_true = y_true[idx]
    class_acc = K.mean(K.equal(K.argmax(class_true, axis=-1), K.argmax(class_preds, axis=-1)))  # multi-class accuracy  
    return class_acc

问题是,我们必须使用 Keras 函数来索引张量。如何为张量创建布尔掩码?

【问题讨论】:

  • 我不熟悉 Keras,不知道您的代码是否可以使用布尔掩码或显式索引。您是否将掩码转换为布尔类型? tf.cast(二进制掩码,tf.bool)。使用 Theano,您可以使用 bool_mask.nonzero() 来获取布尔掩码的索引。让我们知道此解决方案是否有效。
  • 你会接受使用回调的答案吗?
  • 只是为了确保 - y_true 是二维的?这里的行和列应该代表什么?

标签: python tensorflow neural-network keras


【解决方案1】:

请注意,当谈到一类的准确度时,可能指的是以下(不等价的)两个量中的任何一个:

  • recall,对于 C 类,它是标记为 C 类的示例预测为具有类 的比率>C.
  • precision,对于 C 类,是预测为 C 类的示例与实际上标记为类C

您可以只依靠掩码来进行计算,而不是进行复杂的索引。假设我们在这里讨论的是精度(更改为召回将是微不足道的)。

from keras import backend as K

INTERESTING_CLASS_ID = 0  # Choose the class of interest

def single_class_accuracy(y_true, y_pred):
    class_id_true = K.argmax(y_true, axis=-1)
    class_id_preds = K.argmax(y_pred, axis=-1)
    # Replace class_id_preds with class_id_true for recall here
    accuracy_mask = K.cast(K.equal(class_id_preds, INTERESTING_CLASS_ID), 'int32')
    class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
    class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
    return class_acc

如果您想更加灵活,还可以对感兴趣的类进行参数化:

from keras import backend as K

def single_class_accuracy(interesting_class_id):
    def fn(y_true, y_pred):
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        # Replace class_id_preds with class_id_true for recall here
        accuracy_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int32')
        class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
        class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
        return class_acc
    return fn

并将其用作:

model.compile(..., metrics=[single_class_accuracy(INTERESTING_CLASS_ID)])

【讨论】:

  • Precision 和 Recall 可以结合起来,这个度量称为 F1 分数。它是准确率和召回率的调和平均值,是测试准确度的衡量标准。
  • 尽管 F1 分数(以及准确率和召回率)并未考虑真正的负数,但需要注意这一点很重要。选择合适的指标高度依赖于实际操作。
  • 这里的class_num_true是什么?
  • @Nucl3ic 哎呀,我的错,应该是class_id_true,我想,我已经改了。
  • @jdehesa 非常感谢您的回答,这正是我所需要的。快速澄清。 “在这里用 class_id_preds 替换 class_id_true 以进行召回”
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2021-05-21
  • 1970-01-01
  • 2018-10-02
  • 2021-05-17
  • 2020-07-25
  • 2019-10-31
  • 2018-10-27
相关资源
最近更新 更多