图网络(GN)在深度学习短板即因果推理上拥有巨大潜力,很有可能成为机器学习领域的下一个增长点,而图神经网络(GNN)正属于图网络的子集。GNN近期在图分类任务上得到了当前最佳的结果,但其存在平面化的局限,因而不能将图分层表征。现实应用中,很多图信息都是层级表征的,例如地图、概念图、流程图等,捕获层级信息将能更加完整高效地表征图,应用价值很高。在本文中,来自斯坦福等大学的研究者通过在GNN中结合一种类似CNN中空间池化的操作——可微池化,实现了图的分层表征DIFFPOOL在深度GNN的每一层针对节点学习可微分的软簇分配,将节点映射到一组簇中去,然后这些簇作为粗化输入,输入到GNN下一层。
介绍
近年来人们开发图神经网络的兴趣持续激增。图神经网络,即可以在比如社交网络数据或分子结构数据等图结构数据上运行的通用深度学习架构。GNN一般是将底层图作为计算图,通过在图上传递、转换和聚合节点特征信息学习神经网络基元以生成单个节点嵌入。生成的节点嵌入可以作为输入,用于如节点分类或连接预测的任何可微预测层,完整的模型可以通过端到端的方式训练。
然而,现有的GNN结构的主要限制在于太过平坦,因为它们仅通过图的边传播信息,无法以分层的方式推断和聚合信息。例如,为了成功编码有机分子的图结构,就要编码局部分子结构(如单个的原子和与这些原子直接相连的键)和分子图的粗粒结构(如在分子中表示功能单元的原子基团和键)。对图分类任务而言缺少分层结构尤其成问题,因为这类任务是要预测与整个图相关的标签。在图分类任务中应用 GNN,标准的方法是针对图中所有的节点生成嵌入,然后对这些节点嵌入进行全局池化,如简单地求和或在数据集上运行神经网络。这种全局池化方法忽略了可能出现在图中的层级结构,也阻碍了研究人员针对完整图的预测任务建立有效的GNN模型。
研究者在此提出了DIFFPOOL,这是一个可以分层和端到端的方式应用于不同图神经网络的可微图池化模块。DIFFPOOL允许开发可以学习在图的层级表征上运行的更深度的GNN模型。他们开发了一个和CNN中的空间池化操作相似的变体,空间池化可以让深度CNN在一张表征越来越粗糙的图上迭代运行。与标准CNN相比,GNN的挑战在于图不包含空间局部性的自然概念,也就是说,不能将所有节点简单地以的方式池化在一张图上,因为图复杂的拓扑结构排除了任何直接、决定性的的定义。此外,与图像数据不同,图数据集中包含的图形节点数和边数都不同,这使得定义通用的图池化操作更具挑战性。
为了解决上述问题,我们需要一个可以学习如何聚合节点以在底层图上建立多层级架构的模型。DIFFPOOL在深度GNN的每一层学习了可微分的软分配,这种软分配是基于学习到的嵌入,将节点映射为一组聚类。以该方法为框架,作者通过分层的方式「堆叠」了 GNN 层建立了深度 GNN:GNN 模块中层的输入节点对应GNN模块中层学到的聚类簇。因此,DIFFPOOL的每一层都能使图越来越粗糙,然后训练后的DIFFPOOL就可以产生任何输入图的层级表征。本研究展示了DIFFPOOL可以结合到不同的GNN方法中,这使准确率平均提高了7%,并且在五个基准图分类任务中,有四个达到了当前最佳水平。最后,DIFFPOOL可以学习到与输入图中明确定义的集合相对应的可解释的层级簇。
最近工作
作者的工作是基于最近GNN和图分类的研究。
GNN的最近工作。最近几年提出了许多GNN模型,有受CNN启发的方法,还有RNN、递归网络和循环置信传播。大部分方法都属于Gilmer提出的neural message passing框架,在这种观点下GNN是一种message passing算法,节点特征和邻居关系通过GNN而迭代计算出节点表示。Hamilton总结了这个领域的最近进展,Bronstein概述了图卷积的联系。
GNN实现的图分类。GNN应用于许多任务如节点分类、链接预测、图分类和信息化学。在图分类背景下所面临的一个问题是如何更好的通过GNN产生节点嵌入以表示出整个网络的特征。一般的方法有,在最后一层对节点嵌入进行简单求和或平均、引入与所有节点连接的虚拟节点和用深度学习聚合节点嵌入。然而这些方法都有一个限制,不能学习层级表示(所有节点在一个单层进行全局池化),所以不能捕获现实世界的自然结构。最近有一些方法,用CNN将所有节点嵌入串联起来,但是这需要节点的拓扑排序。
最后,最近也有一些工作,将GNN和确定性聚类方法结合起来以学习层级图表示。与这个不同的是,作者的方法是在端对端训练的框架下自动学习层级结构表示,而不是依靠确定性聚类方法。
提出方法
DIFFPOOL的关键想法是在多层GNN结构中引入节点的可微层级池化。这一节概述DIFFPOOL模块以及如何在GNN中应用。
图表示为,其中是邻接矩阵,是节点特征矩阵,是每个节点的特征维数。给定一个带标签的图数据,其中是图的类标签,任务目标是寻找映射。 相对于标准监督学习过程,这里的困难主要在于如何更好的从输入的图中提取特征,为了应用深度学习等机器学习方法进行分类,我们需要将每个图转换成一个有限维向量。
GNN。在这个工作中,作者以端到端训练的方式使用GNN学习提取用于图分类的特征。GNN使用message passing结构,其中是GNN迭代次后的节点嵌入,是由邻接矩阵和参数决定的message传播函数,是上一步message passing产生的节点嵌入。输入节点嵌入初始化为节点输入特征,。
传播函数有多种实现方式。有一种流行的GNN变种GCN,M的实现方式是将线性变换和ReLU非线性**结合起来
其中是需要训练的权重矩阵。作者提出的可微分池化层能应用到任意GNN模型中,不论以何种方式实现。GNN迭代次产生最终节点嵌入,其中一般是之间,以下论述中忽略GNN的内部结构,并简单记为。
GNN和池化层的堆叠。GNN的实现内部是平面化的,信息只能通过边传播。作者的目标是提出一个通用的、端对端可微分的方法,将GNN模块堆叠为层级结构。给定原始图的邻接矩阵后可以产生GNN的输出,之后给出一个粗化的图,粗化图的节点数为,邻接矩阵为,节点嵌入为。这个粗化图作为下一层GNN的输入,经过次重复产生越来越粗化的图,并分别由串联的GNN进行处理。于是我们的目标是学习如何使用上一层GNN的输出结果对节点进行聚类或池化,再把聚类或池化所输出的粗化图作为下一层GNN的输入。设计GNN的池化层是比较困难的,相比于一般的粗化图任务,不是在一个图上对节点进行聚类,而是在图集合上进行层级池化,在推理时要对许多不同的图结构进行自适应池化。
基于分配学习的可微分池化。上述提到的DIFFPOOL,难点在于使用GNN的输出学习节点分配的聚类,将L个GNN堆叠起来,可微池化层利用上一个GNN产生的节点嵌入进行节点聚类,从而产生粗化图,并以端对端方式进行训练学习。于是GNN产生的节点嵌入,既用于图分类,又用于层级池化,而这通过大量的图样本进行训练学习。以下先描述,DIFFPOOL在有了节点分配矩阵后具体如何聚类池化,再描述在GNN架构下如何产生分配矩阵。
用分配矩阵进行池化。将层的聚类分配矩阵记为。的每一行代表在层中的个节点中的一个节点(或一个节点聚类簇),每一列代表层中的节点聚类簇,提供了从层的图节点到层的图节点(聚类簇)的软分配。
现在我们已经有了层的节点分配矩阵,将这层图的邻接矩阵记为,层图的节点嵌入记为。DIFFPOOL可微池化层在此基础上产生输入图的粗化图,,是下一层粗化图的邻接矩阵,是下一层粗化图的节点(聚类簇)输入特征。
公式是根据分配矩阵,将上一层的节点嵌入转换成下一层的节点(聚类簇)嵌入,类似的,公式是将上一层的邻接矩阵转换成下一层粗化图的邻接矩阵。是一个层的全连接实值矩阵,而代表层聚类簇和聚类簇之间的连接强度。类似的,的第行代表第个聚类簇的输入特征。最后,和作为下一层GNN的输入特征。
学习并产生分配矩阵。现在描述DIFFPOOL如何产生分配矩阵。我们使用两个独立的GNN(嵌入GNN和池化GNN)产生两个矩阵,层中的嵌入GNN为
将层邻接矩阵和输入特征作为一个标准GNN的输入,进而产生一个新的嵌入。相比之下,池化GNN则使用和产生分配矩阵
其中应用于输出矩阵的每一行。注意到这两个GNN使用相同的输入数据,但是具有不同的参数和作用:嵌入GNN对输入特征产生节点嵌入,池化GNN对节点产生概率分配,从而对应粗化图的聚类簇。
在层时,公式的输入就是原始图的邻接矩阵和节点特征,而倒数第二层的分配矩阵是全为1的向量,这样就能在最后一层将所有节点归到一个聚类簇,并最后产生一个代表整个图的嵌入向量。最后的嵌入表示作为一个可微分类器的输入特征,整个系统以端对端的形式进行随机梯度下降训练。
置换不变性。为了更好分类,池化层应具有置换不变性。对于DIFFPOOL,作者表明只要GNN组件满足置换不变性,那么整体就会满足了。
令为置换矩阵,若,则
链接预测及正则化。在实际中,仅从图分类的梯度信号中训练池化GNN是困难的,这是一个非凸优化问题。为了缓解这个问题,在训练时在加上链接预测目标,这促使邻近节点一起池化。在每层中,最小化,其中是范数。
另一个池化GNN的重要特征是,节点的聚类簇分配向量应当接近向量,这样能使聚类簇更加清晰明显。此外还要正则化聚类簇分配的熵,为此加入最小化,其中为熵,为的第行。在训练时,每层的和都添加到分类损失一起。在实际中,作者发现这样训练会收敛的更加缓慢,然而效果更好,并且聚类簇的分配也更有解释性。
实验
作者为了评估DIFFPOOL的优势,将DIFFPOOL与最优秀的图分类方法相比,并回答下列问题:
Q1:与其他已提出的GNN池化方法相比,DIFFPOOL如何?
Q2:与现有最好的图分类任务模型相比,结合了DIFFPOOL的GNN如何?
Q3:DIFFPOOL对输入图给出了有意义且可解释的簇吗?
数据集。使用多种图分类基准数据,如蛋白质数据集(ENZYMES,PROTEINS,D&D),社交网络数据集(REDDIT-MULTI-12K),科研合作数据集(COLLAB)。的数据作为验证集,剩下作为训练数据并以折验证方式评估模型结果。
论文地址