双壁合一

卷积神经网络(CNNS)

CNNs 易于并行化,却不适合捕捉变长序列内的依赖关系。

循环神经网络(RNNS)

RNNs 适合捕捉长距离变长序列的依赖,但是自身的recurrent特性却难以实现并行化处理序列。

整合CNN和RNN的优势,Vaswani et al., 2017 创新性地使用注意力机制设计了 Transformer 模型。


该模型利用 attention 机制实现了并行化捕捉序列依赖,并且同时处理序列的每个位置的 tokens ,上述优势使得 Transformer 模型在性能优异的同时大大减少了训练时间。


如图展示了 Transformer 模型的架构,与机器翻译及其相关技术介绍中介绍的 seq2seq 相似,Transformer同样基于编码器-解码器架构,其区别主要在于以下三点:

  1. Transformer blocks:seq2seq循环网络_{seq2seq模型}–> Transformer Blocks

    Transform Blocks模块包含一个多头注意力层(Multi-head Attention Layers)以及两个 position-wise feed-forward networks(FFN)。对于解码器来说,另一个多头注意力层被用于接受编码器的隐藏状态。
  2. Add and norm:多头注意力层和前馈网络的输出被送到两个“add and norm”层进行处理
    该层包含残差结构以及层归一化
  3. Position encoding:由于自注意力层并没有区分元素的顺序,所以一个位置编码层被用于向序列元素里添加位置信息。

Transformer (Google 机器模型)

Transformer. Transformer 架构.

鉴于新子块第一次出现,在此前 CNNS 和 RNNS 的基础上,实现 Transform 子模块,并且就机器翻译及其相关技术介绍中的英法翻译数据集实现一个新的机器翻译模型。

import os
import math
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('path to file storge d2lzh1981')
import d2l

masked softmax

参考Seq2seq1注意力机制和Seq2seq模型_{工具1}

def SequenceMask(X, X_len,value=-1e6):
    maxlen = X.size(1)
    X_len = X_len.to(X.device)
    #print(X.size(),torch.arange((maxlen),dtype=torch.float)[None, :],'\n',X_len[:, None] )
    mask = torch.arange((maxlen), dtype=torch.float, device=X.device)
    mask = mask[None, :] < X_len[:, None]
    #print(mask)
    X[~mask]=value
    return X

def masked_softmax(X, valid_length):
    # X: 3-D tensor, valid_length: 1-D or 2-D tensor
    softmax = nn.Softmax(dim=-1)
    if valid_length is None:
        return softmax(X)
    else:
        shape = X.shape
        if valid_length.dim() == 1:
            try:
                valid_length = torch.FloatTensor(valid_length.numpy().repeat(shape[1], axis=0))#[2,2,3,3]
            except:
                valid_length = torch.FloatTensor(valid_length.cpu().numpy().repeat(shape[1], axis=0))#[2,2,3,3]
        else:
            valid_length = valid_length.reshape((-1,))
        # fill masked elements with a large negative, whose exp is 0
        X = SequenceMask(X.reshape((-1, shape[-1])), valid_length)
 
        return softmax(X).reshape(shape)

# Save to the d2l package.
class DotProductAttention(nn.Module): 
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_length: either (batch_size, ) or (batch_size, xx)
    def forward(self, query, key, value, valid_length=None):
        d = query.shape[-1]
        # set transpose_b=True to swap the last two dimensions of key
        scores = torch.bmm(query, key.transpose(1,2)) / math.sqrt(d)
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        return torch.bmm(attention_weights, value)

多头注意力层


引入:自注意力(self-attention)

自注意力模型是一个正规的注意力模型,序列的每一个元素对应的key,value,query是完全一致的。

与循环神经网络相比,自注意力对每个元素输出的计算是并行的,所以我们可以高效的实现这个模块。

Transformer (Google 机器模型)

自注意力结构

输出了一个与输入长度相同的表征序列


多头注意力层包含hh并行的自注意力层,每一个这种层被成为一个head。

对每个头来说,在进行注意力计算之前,我们会将query、key和value用三个现行层进行映射,这hh个注意力头的输出将会被拼接之后输入最后一个线性层进行整合。

Transformer (Google 机器模型)

多头注意力

假设query,key和value的维度分别是dqd_qdkd_kdvd_v。那么对于每一个头i=1,,hi=1,\ldots,h,我们可以训练相应的模型权重Wq(i)Rpq×dqW_q^{(i)} \in \mathbb{R}^{p_q\times d_q}Wk(i)Rpk×dkW_k^{(i)} \in \mathbb{R}^{p_k\times d_k}Wv(i)Rpv×dvW_v^{(i)} \in \mathbb{R}^{p_v\times d_v},以得到每个头的输出:

o(i)=attention(Wq(i)q,Wk(i)k,Wv(i)v) o^{(i)} = attention(W_q^{(i)}q, W_k^{(i)}k, W_v^{(i)}v)

这里的attention可以是任意的attention function,之后我们将所有head对应的输出拼接起来,送入最后一个线性层进行整合,这个层的权重可以表示为WoRd0×hpvW_o\in \mathbb{R}^{d_0 \times hp_v}

o=Wo[o(1),,o(h)] o = W_o[o^{(1)}, \ldots, o^{(h)}]

接下来实现多头注意力,假设有h个头,隐藏层权重 hidden_size=pq=pk=pvhidden\_size = p_q = p_k = p_v 与query,key,value的维度一致。除此之外,因为多头注意力层保持输入与输出张量的维度不变,所以输出feature 的维度也设置为 d0=hidden_sized_0 = hidden\_size

MultiHeadAttention class

相关文章:

  • 2021-12-25
  • 2021-07-01
  • 2021-04-19
  • 2021-11-23
  • 2021-10-21
  • 2022-01-14
猜你喜欢
  • 2021-07-21
  • 2021-09-30
  • 2022-01-17
  • 2022-02-16
  • 2021-05-31
  • 2021-11-17
相关资源
相似解决方案