【问题标题】:Tensorflow Argmax equivalent for multilabel classification用于多标签分类的 Tensorflow Argmax 等效项
【发布时间】:2019-05-14 18:26:08
【问题描述】:

我想对分类 Tensorflow 模型进行评估。

为了计算准确性,我有以下代码:

predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
accuracy = tf.metrics.accuracy(labels=label_ids, predictions=logits)

它在单标签分类中效果很好,但现在我想做多标签分类,我的标签是整数数组而不是整数。

这是一个存储在label_ids 中的标签[0, 1, 1, 0, 1, 0] 的示例,以及来自张量logits 的预测[0.1, 0.8, 0.9, 0.1, 0.6, 0.2] 的示例

我应该使用什么函数来代替argmax 这样做? (我的标签是 6 个整数的数组,值为 0 或 1)

如果需要,我们可以假设阈值为 0.5。

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    在 tensorflow 之外进行这种类型的后处理评估可能会更好,在这种情况下尝试几个不同的阈值更自然。

    如果你想在tensorflow中做,可以考虑:

    predictions = tf.math.greater(logits, tf.constant(0.5))
    

    对于所有大于 0.5 的条目,这将返回原始 logits 形状的张量,其值为 True。然后,您可以像以前一样计算准确性。这适用于给定样本的多个标签可以同时为真的情况。

    【讨论】:

      【解决方案2】:

      使用以下代码来确定多类分类的准确性:

      tf.argmax 将为y_predy_true(实际y)返回y 值为max 的轴。

      进一步tf.equal用于查找匹配的总数(它返回True,False)。

      将布尔值转换为浮点数(即0或1)并使用tf.reduce_mean计算精度。

      correct_mask = tf.equal(tf.argmax(y_pred,1), tf.argmax(y_true,1))
      accuracy = tf.reduce_mean(tf.cast(correct_mask, tf.float32))
      

      编辑

      数据示例:

      import numpy as np
      
      y_pred = np.array([[0.1,0.5,0.4], [0.2,0.6,0.2], [0.9,0.05,0.05]])
      y_true = np.array([[0,1,0],[0,0,1],[1,0,0]])
      
      correct_mask = tf.equal(tf.argmax(y_pred,1), tf.argmax(y_true,1))
      accuracy = tf.reduce_mean(tf.cast(correct_mask, tf.float32))
      
      with tf.Session() as sess:
        # print(sess.run([correct_mask]))
        print(sess.run([accuracy]))
      

      输出:

      [0.6666667]
      

      【讨论】:

      • 如果我理解得很好,要使此解决方案起作用,y_pred 的值必须为 0 或 1(如 y_true)?就我而言,我有一个浮点值介于 0 和 1 之间的张量。所以我需要在此之前进行转换(带有阈值)?
      • @Nakeuh 不,您可以在y_pred 中包含数值,它将返回具有最大值的轴。
      • 我想做MultiLabel分类。这是一个例子:y_pred[0.1, 0.1, 0.7, 0.8, 0.2, 0.9]y_true[0, 0, 1, 1, 0, 1],在这种情况下,预测是正确的。我不认为我们在谈论同一件事。
      • @Nakeuh 我添加了示例
      猜你喜欢
      • 1970-01-01
      • 2016-06-04
      • 1970-01-01
      • 2017-01-08
      • 2018-06-11
      • 2021-10-04
      • 2017-06-03
      • 2016-05-25
      相关资源
      最近更新 更多