一、相关工作
NIC代码分析:https://blog.csdn.net/zlrai5895/article/details/82667804
本文部分内容来自:https://blog.csdn.net/shenxiaolu1984/article/details/51493673
二、 基本思想
文章在NIC的基础上加入了attention机制
三、模型结构
对LSTM部分做出的改动,其余与NIC相同。
四、代码分析
(0)预处理
首先是把数据中长度大于20的caption删除,这是第一次筛选。然后建立词汇库(代码中的大小为5000), 对数据再次筛选,只保留所有的单词都在词汇库中的句子。建立数据集,一共361258个图像和caption对。数据集包含几个部分:
self.image_ids = np.array(image_ids) #每一幅图像的id
self.image_files = np.array(image_files) #每一幅图像对应的路径
self.word_idxs = np.array(word_idxs) #(361258,20)每个caption中的单词用单词对应的id代替词汇库中单词的下标索引(包括标点符符号和<start>)
self.masks = np.array(masks) #(361258,20) 用来记录长度,有词是1,没词是0
self.batch_size = batch_size #定义读取数据时候的batch_size
(1) 首先是CNN(VGG网络)提取特征,最后得到的特征图是(batch_size,16,16,512),16*16代表了原本图像196个区域,每个区域用512维的特征来表示。rshape成(batch_size,196,512)
reshaped_conv5_3_feats = tf.reshape(conv5_3_feats,[config.batch_size, 196, 512])
(2)建立 embedding_matrix
大小为(5000,512),也就是说词汇库里的5000个单词,每个单词用512维的向量来表示 并且做初始化(使用预训练好的word_embedding)
(3)建立RNN
不再直接使用图像特征a,而是对不同的区域加上不同的权重,得到上下文z(context)。
首先,图像特征作为最初的context.,使用两个全连接层得到最初的memory(c0)和out(o0),作为LSTM最初的state。
context_mean = tf.reduce_mean(self.conv_feats, axis = 1) #图像特征作为最初的context (batch_size,512)
initial_memory, initial_output = self.initialize(context_mean)#使用两个全连接层得到最初的memory(c)和out(o)
initial_state = initial_memory, initial_output #最初的输入state
输入的caption默认为(batch_size,max_length),这里max_length取20。不到20的后面补0,并且用masks做了标记.
对于每个时刻的单词,首先引入attention,加入权重。并得到加权后的context和masks。
αt维度为L=196L=196,记录释义aa每个像素位置获得的关注。
权重αt可以由前一步系统隐变量htht经过若干全连接层获得。编码et用于存储前一步的信息。灰色表示模块中有需要优化的参数。
“看哪儿”不单和实际图像有关,还受之前看到东西的影响。
第一步权重完全由图像特征aa决定:
alpha = self.attend(contexts, last_output) #引入注意力机制,加入权重 (batch_size,196)对196个区域的权重
context = tf.reduce_sum(contexts*tf.expand_dims(alpha, 2),
axis = 1) #加权之后的context (batch_size,512)
if self.is_train:
tiled_masks = tf.tile(tf.expand_dims(masks[:, idx], 1),
[1, self.num_ctx]) #(batch_size,196) masks[:, idx] 全部批次某个时刻的mask
masked_alpha = alpha * tiled_masks #得到加权后的结果 如果maskd对应的是0 权重也就变成了0
alphas.append(tf.reshape(masked_alpha, [-1])) #masked_alpha: (batch_size,196)
githubs上tensorflow 版本的代码 attend部分的具体实现略有不同,这里就不再给出细节。
把当前时刻的权重存入列表。
alphas.append(tf.reshape(masked_alpha, [-1])) #masked_alpha: (batch_size,196)
把word_embedding和加权之后的context连接起来,作为当前时刻的输入,得到out_put和state.
current_input = tf.concat([context, word_embed], 1) #当前时刻的输入是 加权后context 和word_embeeding的结合 (bacth_size,1024)
output, state = lstm(current_input, last_state) #(batch_size,512)
memory, _ = state #其他show and tell一样 (bacth_size,512) (batch_size,512)
利用得到的输出和加权的context计算下一个单词的概率。做出预测
logits = self.decode(expanded_output) #(bacth_size,5000)
probs = tf.nn.softmax(logits)
prediction = tf.argmax(logits, 1)
最后,为下个时刻提供上个时刻的输出和state等。
last_output = output
last_memory = memory
last_state = state
last_word = sentences[:, idx] #开始下一个单词