摘要:多标签分类问题中,要想提高预测性能,需要充分利用好标签之间的相互依赖。文献【1】指出可以利用RNN的中间状态来表示标签之间的关系,提出CNN-RNN framework来学习标签之间的依赖关系以及样本-标签之间的关系,文献【2】进一步开发RNN的记忆能力,提出Rethinking Structure学习标签之间的关系,同时使用一种标签敏感的代价函数进行训练。本文对两篇文献的主要思想进行总结。
目录
- CNN-RNN: A Unified Framework for Multi-label Image Classification
- Deep Learning with a Rethinking Structure for Multi-label Classification
主要参考文献
【1】“CNN-RNN: A Unified Framework for Multi-label Image Classification”
【2】“Deep Learning with a Rethinking Structure for Multi-label Classification”
1. CNN-RNN: A Unified Framework for Multi-label Image Classification
摘要:尽管深度卷积神经网络已经在单标签的图像分类任务中取得很大成功,但现实中的图像通常是多标签的。这意味着一个图像会有不同的目标,场景,动作和属性。传统的多标签图像分类方法对每个标签独立学习一个分类器,然后通过排序或者阈值得到最终的分类结果。这些方法尽管也有用,但它们不能挖掘出标签之间的依赖关系。本文中,我们利用RNN来处理这个问题。将其与CNN结合,我们提出的CNN-RNN结构可以学习到语义标签的依赖性和图片-标签的相互关系,同时可以进行端到端的训练。实验中取得了当时最佳的结果。
方法:网络结果如图所示。
- 为了分析标签之间的高阶关系,使用LSTM作为RNN神经元。
- 使用CNN对图像进行特征提取,使用RNN对多标签进行编码。
- 训练时,网络输入为图像和标签,网络输出为标签。
- 预测时,使用束搜索对预测序列进行判断,选择最优解。每次搜索过程中,使用CNN进行特征提取,将特征和当前的标签结合,作为预测层的输入,得到输出序列,然后进行下一步地搜索。由原文公式(4),提取的特征和标签先通过线性变换,相加,再投影到预测层,因此第一次预测的时候不需要标签,之后的每一次预测都用到前一次的预测结果。预测步骤如下图所示。
代码结构:
- 两个编码器:CNN和RNN.
- 一个解码器:Dense.
然后可以像Sequence-to-Sequence那样训练。或者直接构造多输入模型,预测时第一次的标签全部取0。
2. Deep Learning with a Rethinking Structure for Multi-label Classification
摘要:多标签分类是一类重要的机器学习问题。当处理多标签分类问题时,能够处理标签之间隐藏关系的算法能获得更好的性能,而提取标签之间的关系是一项困难的任务。本文中,我们通过RNN的记忆结构提出一种新的深度学习方法来更好的学习标签之间的相互关系。该结构再最终的预测前能很好的考虑不同标签之间的关系。另外,在Rethinking的过程能容易满足不同需求的代价函数。当然也取得了当时最佳的结果。
方法:Rethinking结构如图所示。
- Rethinking结构的实现中,看作者源码首先是将输入向量堆成T个时间步,然后把每个时间的输出都拼接到一起,作为最后全连接层的输入。
- 训练过程中使用标签敏感的代价函数,对常规的二分类代价函数用不同标签值的损失大小进行加权。
代码地址:https://github.com/yangarbiter/multilabel-learn.