【发布时间】:2021-12-19 12:07:06
【问题描述】:
我目前正在使用 TensorFlow 解决多标签分类问题(总共 9 个标签),这是模型编译行:
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
y_true 我的模型的标签由 2 个 1 和 7 个 0 组成(例如,[0,1,0,0,0,1,0,0 ,0])。
我用TensorFlow尝试了几个模型,但无论模型多么复杂,准确率一直很差,准确率在0.3左右。
我想知道 Keras 准确度指标是否也适用于多标签分类。例如,如果 y_pred 的概率值为 [0.1, 0.9, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0.1],Keras 是否会选择顶部y_pred中的2个概率,将它们转换为[0, 1, 0, 0, 1, 0, 0, 0, 0]的二进制标签,然后将准确率与y_true 标签的那些?
如果没有,我是否必须实现自己的指标功能?
提前致谢!
【问题讨论】:
-
不要使用 Keras,但我认为预测的编码是这样的,如果单个标签概率超过 0.5,那么它被编码为 1,否则为零。例如,如果使用多元线性回归,您可以将代码欺骗为 y=(-1,1,-1,-1,-1,1,-1,-1,-1) 的真值标签,然后任何正预测被编码为 1,负预测被编码为零,例如 y=(-0.3,0.6,-0.4,-0.2,-0.3,0.4,-0.2,-0.5,-0.4)。点是,零是阈值。在您的情况下,0.5 可能是阈值。
-
另一点是对于长度为 9 的二进制字符串,有 2^9=512 个可能的 1 和 0 组合。因此,您需要确保广泛覆盖结果空间(类标签)来训练您的模型。性能不佳可能是由于数据和所选模型造成的。也许尝试使用 MSE 来解决错误而不是交叉熵。 MSE 将等于 0.5 * [(0-0.1)^2 + (1-0.9)^2 + ... + (0-0.1)^2]。我很少将交叉熵用于 ANN,而是经常使用 MSE 进行分类和函数逼近。
-
也许您应该尝试将损失函数更改为“categorical_crossentropy”,因为您有一个多类问题。
标签: tensorflow keras