开放域对话预训练模型总结,主要包括 PLATO,DialoGPT,Meena,Blender,PLATO-2

进行训练数据、模型架构、预训练任务、模型解码、结果评价等方面的对比

开放域对话预训练模型总结

相关资源

论文、官方博客新闻以及GitHub代码

Model Paper & Blog GitHub
DialoGPT DIALOGPT : Large-Scale Generative Pre-training for Conversational Response Generation
Microsoft Blog
microsoft/DialoGPT
yangjianxin1/GPT2-chitchat
Meena Towards a Human-like Open-Domain Chatbot
Google AI Blog
google-research/meena
Blender Recipes for building an open-domain chatbot
ParlAI Docs
facebookresearch/ParlAI
PLATO PLATO: Pre-trained Dialogue Generation Model with Discrete Latent Variable
Baidu AI
PaddlePaddle/Research
PLATO-2 PLATO-2: Towards Building an Open-Domain Chatbot via Curriculum Learning
Baidu AI
PaddlePaddle/Knover

训练数据

大规模类似对话数据预训练

Model Dataset Size Dataset Source
DialoGPT 147M samples(1.8B words) Reddit
Meena 341GB text(40B words) Reddit
Blender 1.5B samples(88.8B words) Reddit
PLATO 8.3M samples Twitter/Reddit
PLATO-2 684M(EN)/1.2B(ZH) samples Reddit/Twitter (Chinese unk)
  • sample 构造成 (context, response) 形式
  • 均使用 Byte-Level BPE tokenizations
  • PLATO-2 中文数据(来自 GitHub issue
    • 考虑到不同的来源上文谈话的内容风格和话题差异比较大,为了保证训练语料的多样性和覆盖面,PLATO-2训练使用的中文数据来源是很丰富的(并不局限在一个数据集上),包括多个公开数据集,脱敏的社交媒体数据,以及人工标注数据

模型架构

图片来自T5论文

开放域对话预训练模型总结

Model Architecture Parameters
BERT Transformer Encoder 110/340M (Chinese 110M)
DialoGPT Language Model (GPT-2) 117/345/762M
Meena Seq2Seq Transformer (ET) 2.6B
Blender PolyEncoder + Seq2Seq Transformer 90M/2.7B/9.4B
PLATO Prefix LM (UniLM) 110M (132M)
PLATO-2 Prefix LM (UniLM) 310M/1.6B (Chinese 333M)
  • Retrieve and Refine Model
    • Blender 中使用检索优化模型,在生成之前组合检索步骤,包括 Dialogue Retrieval 和 Knowledge Retrieval
      • Dialogue retrieval,使用 PolyEncoder 检索回复,将检索回复拼接到输入之前
        • 直接对话检索回复拼接到对话历史后面效果不好,需要随机替换为gold response,强迫生成模型从检索结果中 copy 词或者词组
      • Knowledge retrieval,使用 Wizard Generative model
        • 分别利用当前对话 topic(对话 topic 会事先告知)和最后两轮对话,各自检索出 top-K 的文章
        • 把所有文章各自分句,然后把各自文章的 title 追加到每个句子最前面,获得很多候选句子
        • 再利用 poly-encoder 结构的模型对候选句子排序,最终使用 top-1 的句子作为检索结果
        • 同时还会训练一个单独的分类器来判断是否需要从知识库中检索知识。回复某些对话 context 可以不需要额外知识,这时候就不用追加检索结果
  • UniLM 可以在 BERT 基础上改变 Mask 矩阵实现(上文使用双向 Attention,回复使用单向 Attention)
    • 论文中经过各项任务的反复验证,UniLM 在同等规模参数量的情况下具有最佳的性价比
    • PLATO 直接使用 BERT预训练模型参数初始化(然后再使用对话数据训练)

训练任务

Response Generation 为主要训练任务

Model Pre-training Obejctive Loss Function Cost
DialoGPT Language modeling (+MMI) Cross-entropy Loss 16 V100 -
Meena Response generation Cross-entropy Loss TPU v3 Pod 30 days
Blender Ranking for retrieval (+MLM)
Likelihood training for generation
Unlikelihood training for generation
Cross-entropy Loss MLE -
PLATO Response generation
Latent act recognition
Response selection
Negative log likelihood loss (NLL)
Bag-Of-Words (BOW)
Loss Binary Cross-entropy Loss
8 V100 32G 2 weeks
PLATO-2 Ibid.
Response selection (+MLM)
NLL/BOW+RCE/MLM 64 V100 3 weeks
  • DialoGPT

    • Language Modeling
      p(TS)=n=m+1Np(xnx1,,xn1) p(T|S) = \prod_{n=m+1}^N p (x_n | x_1, \cdots, x_{n-1})
  • Blender

    • Ranking for Retrieval,在检索中使用同一 batch 内其他回复作为反例

    • Likelihood Training for Generation

      • Given a dataset D={(x(i),y(i))}\mathcal{D} = \{(\mathbf{x}^{(i)}, \mathbf{y}^{(i)})\}, minimize

      LMLE(i)(pθ,x(i),y(i))=t=1y(i)logpθ(yt(i)x(i),y<t(i)) \mathcal{L}_{MLE}^{(i)} (p_{\theta}, \mathbf{x}^{(i)}, \mathbf{y}^{(i)}) = - \sum_{t=1}^{|y^{(i)}|} {\log p_{\theta}(y_t^{(i)} | \mathbf{x}^{(i)}, y_{<t}^{(i)})}

      • x(i)\mathbf{x}^{(i)} 是输入对话历史,y(i)\mathbf{y}^{(i)} 是对话回复
    • α\alpha-blending for retrieve and refine

      • 当添加检索结果后,MLE 的生成往往会忽视检索的结果
      • 训练阶段,随机将 α\alpha 比例的正确对话回复代替检索结果拼接(实现在检索和生成之间的平稳过渡)
    • Unlikelihood Training for Generation

      • 提升正确token概率同时,降低其他token 的概率(惩罚项);在选择负token时,选择容易组成常见n-gram的tokens;期望降低生成无意义回复的比例

      • Unlikelihood Loss 在每步 tt 会惩罚一部分 token,记为 Ct\mathcal{C}_t
        LUL(i)(pθ,C1:T,x,y)=t=1yycCtlog(1pθ(ycx,y<t)) \mathcal{L}_{UL}^{(i)}(p_{\theta}, \mathcal{C}_{1:T}, \mathbf{x}, \mathbf{y}) = - \sum_{t=1}^{|y|} \sum_{y_c \in \mathcal{C}_{t}} \log (1 - p_{\theta}(y_c | \mathbf{x}, y_{<t}))

      • 最终 Loss 为 MLE 和 UL 混合,混合比例为超参数
        LULE(i)=LMLE(i)+αLUL(i) \mathcal{L}_{ULE}^{(i)} = \mathcal{L}_{MLE}^{(i)} + \alpha \mathcal{L}_{UL}^{(i)}

  • PLATO

    • Response SelectionLatent act Recognition 同时进行,使用双向注意力编码;负例从数据集中随机选择
      LRS=logp(lr=1c,r)logp(lr_=0c,r_)p(lr=1c,r)=sigmoid(W3h[M]+b3) \mathcal{L}_{RS} = - \log p(l_r = 1 | c, r) - \log p(l_{r^\_}=0|c, r^\_) \\ p(l_r=1|c,r) = sigmoid(W_3 h_{[M]} + b_3)

      • 其中 h[M]RDh_{[M] \in \mathbb{R}^D} 是特殊 mask 标记的最后隐层状态表示(离散隐变量识别的实现中使用 Gumbel-Softmax,避免采样后梯度消失)
    • response generation 使用单向注意力解码,计算 NLL lossBOW loss

      • NLL loss
        LNLL=Ezp(zc,r)logp(rc,z)=Ezp(zc,r)t=1Tlogp(rtc,z,r<t) \mathcal{L}_{NLL} = - \mathbb{E}_{z \sim p(\mathbf{z}|c, r)} \log p(r| c, z) \\ = - \mathbb{E}_{z \sim p(\mathbf{z}|c, r)} \sum_{t=1}^T \log p(r_t| c, z, r_{<t}) \\

        • 其中 zz(c,r)(c,r) 的隐变量取值,通过概率分布p(zc,r)p(\mathbf{z}|c,r)获得

        p(zc,r)=softmax(W1h[M]+b1)RK p(\mathbf{z}|c,r) = softmax(W_1 h_{[M]} + b_1) \in \mathbb{R}^K

      • BOW loss(和NLL相比,BOW打乱了语序,促使离散隐变量学习目标回复的局部信息)
        LBOW=Ezp(zc,r)t=1Tlogp(rtc,z)=Ezp(zc,r)t=1TlogefrtvVefv \mathcal{L}_{BOW} = - \mathbb{E}_{z \sim p(\mathbf{z}|c,r)} \sum_{t=1}^T \log p(r_t|c,z) \\ = - \mathbb{E}_{z \sim p(\mathbf{z}|c,r)} \sum_{t=1}^T \log \frac{e^{f_{r_t}}}{\sum_{v \in V} e^{f_v}}

        • 其中 VV是整个词表,ff 是以非自回归的方式预测目标回复中的单词(hzh_z是离散隐变量的隐层表示)
          f=softmax(W2hz+b2)RV f = softmax(W_2 h_z + b_2) \in \mathbb{R}^{|V|}
    • 最终优化目标
      L=LNLL+LBOW+LRS \mathcal{L}= \mathcal{L}_{NLL} + \mathcal{L}_{BOW} + \mathcal{L}_{RS}

  • PLATO-2

    • 采取了课程学习的方法,逐步优化参数,加快训练效率
    • PLATO-2 采用了 GPT-2 的前置正则化层的方式,以更好适应大规模训练的需求
    • 训练步骤
      • **第一步,PLATO-2 先训练了不含隐变量的模型。**该模型进行的是简化的“一对一”建模,容易生成安全回复
      • **第二步,在前一步模型基础上,添加上隐变量,然后同时训练 Generation + Recognition 和 Response Selection 两个不同模型。**其中,Response Selection 模型在合适度预测的基础上,还添加了 Masked Language Model 作为辅助任务,以强化模型对语义的理解。

模型解码

几乎都在解码中尝试各种方式来提高回复生成的多样性

各种解码方法参考 huggingface blog

Model Method
DialoGPT MMITop-K sampling
Meena Sample-and-rank,Top-K-sampling(Sampling is better than Beam Search)
Blender Beam Search(Length Controlling + Subsequence blocking)
PLATO Beam Search(Response Selection)
PLATO-2 Beam Search(Response Coherence Estimation)
  • Maximum mutual information(MMI)

    • 使用预训练的后向模型来预测从给定回复 target 到对话历史 source 的生成概率 i.e., P(Sourcetarget)P(Source|target)
    • MMI 倾向于惩罚枯燥乏味的回复,频繁和重复的回复会和许多上文关联,因此对于某些特殊的回复会返回较低的分数
    • 论文中使用 345M 的 GPT-2 medium model 作为 Backward model
  • Sample-and-rank

    强大的模型生成回复的 ppl 足够小,可以牺牲 ppl 来生成多样性回复

    1. sample N independent candidate responses using plain random sampling with temperature T
    2. select the candidate response with the highest probability to use as the final output(Hinton et al., 2015)

    给定当前所有单词的概率,T越大,概率分布越平缓,概率差距缩小,容易随机取到其他的单词

    pi=exp(zi/T)jexp(zi/T) p_i = \frac{exp(z_i/T)}{\sum_j exp(z_i/T)}

  • Beam Search(Length Controlling,Subsequence Blocking)

    Beam Search改进方法较多,如果模型本身不够好,起到的作用很小(与 sampling 效果差不多)

    • 控制生成回复的最小长度
      • Minimum Length,设置最小长度阈值(20效果比不设置好)
      • Predictive Length,额外训练回复长度分类器(<10/20/30,>30, poly-encoder),增大模型复杂度
    • 屏蔽重复的子序列(Subsequence Blocking),不允许产生当前句子和前面对话(context)中已经存在的 3-grams(不显著)

效果评估

Fine-tune

进行效果测试的下游任务

Model Dataset
DialoGPT DSTC-7 Dialogue Generation Challenge
Reddit
Meena -
Blender ConvAI2
Wizard of Wikipedia(WoW)
Empathetic Dialogues
Blended SKill Talk
PLATO Persona-Chat
Daily Dialog
DSTC7-AVSD
PLATO-2 -

评价方式

Model Static Evaluation Dynamic Evaluation 3-Level Pairwise Evaluation Desc
DialoGPT ✔︎ ✔︎ ✔︎ fine-tuning
Meena ✔︎ ✔︎ ✔︎ human-bot chat SSA
PPL
Blender ✔︎ ✔︎ ✔︎ fine-tuning,self-chat
PLATO ✔︎ ✔︎ fine-tuning
PLATO-2 ✔︎ ✔︎ ✔︎ ✔︎ self-chat,actual-eval
  • Static Evaluation
    • 输入预设的对话历史,要求模型预测回复(fine-tune);自动评价+人工评价
  • Dynamic Evaluation
    • 和模型进行动态对话交互,根据机器人的每轮回复进行评分(一般打分区间为 {1,2,3})
    • (PLATO-2英文采用 Self-Chat方式,给定启动句,人工评估多轮对话效果;中文采用 Human-Bot Chat)
  • Pairwise Evaluation
    • Human-Bot Chat 耗费人力、物力,采用机器人之间聊天的方式
    • 提供2组(machine-human/machine-machine)对话日志,(针对每组中的特定机器人)选择更有吸引力和接近人类的对话(ACTUAL-Eval,Self-Chat)
  • Meena 中提出 使用新的人工评价指标SSA(sensibleness and specificity avarage);包括合理性(常识正确性、逻辑连贯性和一致性)和特定性(针对对话历史);论文中证明 PPL 和 SSA 存在高度负相关性

未来改进

Model Model
DialoGPT -
Meena SSA指标还可以加入其他的人类对话属性
包括幽默,同理心,深入的推理,问题解答和知识讨论技巧
Blender 倾向于使用高频词
倾向于生成重复信息(copy)
内容冲突和遗忘
无法深度对话(外部知识的使用)
无法深度理解
更长的对话仍然存在问题(低成本评测?)
PLATO -
PLATO-2 探索强化学习应用

相关文章:

  • 2022-01-16
  • 2021-06-23
  • 2021-11-19
  • 2021-06-30
  • 2022-12-23
  • 2021-10-31
  • 2021-10-16
猜你喜欢
  • 2021-11-05
  • 2021-09-13
  • 2022-01-23
  • 2022-12-23
  • 2021-09-19
  • 2021-08-14
  • 2022-02-01
相关资源
相似解决方案