Meta-Learning论文笔记:Meta Network

Meta Network||论文笔记

本文是对2017年ICML的一篇Meta-Learning论文的笔记论文连接

MetaNet 是Meta Networks的缩写,具有用于跨任务快速泛化的体系结构和训练流程。

名词说明:Fast Weight 和 Slow Weight

模型的跨任务快速概括依赖于fast weight。神经网络中的参数通常是根据目标函数中的梯度下降来更新的,这个过程对于小样本学习是很慢的。一种更快的学习方法是利用一个神经网络预测另一个神经网络的参数,生成的参数称为快权值即fast weight。普通的基于SGD等优化的参数被称为慢权值即slow weight。

在 MetaNet 中,损失梯度信息被作为meta information ,用来生成快权重。在神经网络中,将慢权值和快权值结合起来进行预测。

Meta Network||论文笔记

多层叠加Layer Augmentation

模型:

整体架构:

如图,MetaNet的训练包括三个主要过程: meta information的获取、以及fast weight的生成和slow weight的优化,由base learner和meta learner共同执行。

Meta Network||论文笔记

MetaNet的整体结构

数据集和主要的函数说明:

  • 训练数据包含两种数据集:支持集 Meta Network||论文笔记Meta Network||论文笔记Meta Network||论文笔记和训练集 Meta Network||论文笔记
  • Base learner简写为 b,是一个函数或神经网络。通过任务损失 Meta Network||论文笔记 估计主要任务目标。它的参数由慢权值 example-level的快权值 Meta Network||论文笔记 构成
  • 动态表征函数 u,对样本学习到一个嵌入。参数由慢权值 example-level快权值 Meta Network||论文笔记 组成
  • Meta learner由快速权值生成函数 和 组成,参数为 和 G,它们的输入由损失梯度 Meta Network||论文笔记和 Meta Network||论文笔记 构成,经过映射后生成 Meta Network||论文笔记和 Meta Network||论文笔记和其对应慢权值维度相同

训练过程:

1. 表征函数的学习:将随机采样的支持集数据输入到表征(嵌入)函数 中,为了得到数据集的嵌入,利用表征损失 Meta Network||论文笔记 来捕获表示学习目标,并将梯度作为meta information获取。其中损失函数为:Meta Network||论文笔记它的具体计算是随机抽取 对支持集样本的来计算嵌入损失:Meta Network||论文笔记其中 Meta Network||论文笔记 是辅助标签:Meta Network||论文笔记其实也就是个二分类,属于所有的支持集样本嵌入做距离计算后经过映射或 Meta Network||论文笔记 函数转化为概率,就成为二分类问题。每次任务损失反向传播得到其损失梯度信息:Meta Network||论文笔记对函数 每次任务损失反向传播得到其梯度信息 Meta Network||论文笔记 ,通过快权值生成函数 的映射得到快权值 Meta Network||论文笔记 :Meta Network||论文笔记2. 快权值的生成:对每个支持集样本数据输入到Base learner函数 中,之后计算出预测的标签和支持集实际的标签通过交叉熵等损失函数计算 Meta Network||论文笔记 :Meta Network||论文笔记生成Base learner 的快权值需要支持集的meta information,即利用支持集的损失梯度信息:Meta Network||论文笔记函数 从损失梯度 Meta Network||论文笔记 中学到一个映射,映射后得到快权值 Meta Network||论文笔记 :

Meta Network||论文笔记这个快权值 Meta Network||论文笔记 存储在 Meta Network||论文笔记 中。

3. 建立支持集的索引:利用参数为快权值 Meta Network||论文笔记 和慢权值 Meta Network||论文笔记 的表征函数 支持集进行建立索引(有快权值的支持集的嵌入) Meta Network||论文笔记 :Meta Network||论文笔记4. 建立训练集的索引:与上一步类似,通过具有慢权值和快权值的表征函数 训练集建立查询索引(对训练集的嵌入):Meta Network||论文笔记5. 对快权值的读取:如果参数 Meta Network||论文笔记 存储在 Meta Network||论文笔记 中且索引 Meta Network||论文笔记 已经建立,用attention(这里的attention用余弦相似度计算存储索引和输入索引)在之前建立的所有支持集的索引 Meta Network||论文笔记 和每一个训练集的索引计算一个相似分数:Meta Network||论文笔记然后经过归一化后用于读取存储 Meta Network||论文笔记 得到最终的快权值:Meta Network||论文笔记6. 训练集标签的预测:Base learner函数 有了慢权值 Meta Network||论文笔记 和快权值 Meta Network||论文笔记 后那么执行one-shot分类为:Meta Network||论文笔记这里的 Meta Network||论文笔记 是对 Meta Network||论文笔记 的预测输出,另外这里的输入也可以用训练集的嵌入 Meta Network||论文笔记 代替。最终训练集损失的计算:Meta Network||论文笔记整个网络的训练参数是 Meta Network||论文笔记 ,通过像反向传播算法去最小化任务损失 。

MetaNet的训练算法如图所述:

Meta Network||论文笔记

MetaNet论文中的算法

论文在Omniglot、Mini-ImageNet 和 MNIST 三种数据集上做了One-Shot实验,实验结果都不错,具体可以看一下论文。

总结:该模型利用损失梯度作为元信息来计算快权值,能够快速适应新的不同的任务,增强在训练样本少的情况下的学习效果。效果其实也不是很强,有很多可以改善的点,并且具体训练的时候因为生成快权值的神经网络参数较多或用的LSTM这样的网络所以比较慢。

相关文章: