【发布时间】:2019-05-14 18:26:08
【问题描述】:
我想对分类 Tensorflow 模型进行评估。
为了计算准确性,我有以下代码:
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
accuracy = tf.metrics.accuracy(labels=label_ids, predictions=logits)
它在单标签分类中效果很好,但现在我想做多标签分类,我的标签是整数数组而不是整数。
这是一个存储在label_ids 中的标签[0, 1, 1, 0, 1, 0] 的示例,以及来自张量logits 的预测[0.1, 0.8, 0.9, 0.1, 0.6, 0.2] 的示例
我应该使用什么函数来代替argmax 这样做? (我的标签是 6 个整数的数组,值为 0 或 1)
如果需要,我们可以假设阈值为 0.5。
【问题讨论】:
标签: python tensorflow