下载地址:https://arxiv.org/pdf/1912.03761v1.pdf
code:https://github.com/NYUMedML/GNN_for_EHR
摘要
基于电子病历和图神经网络预测AD。
本文根据真实的EHR能够提前12-24个月预测AD,并且在稀疏电子病历中改善了预测性能。另外,本文通过模型学习每个图结构,进一步探索了不同诊断、实验室值以及电子病历流程之间的结构关系。
引言:
对于非网格数据,利用图神经网络提取数据特征对于卷积神经网络具有很好的泛化性能。Electronic Health Records (EHR) 具有稀疏性,而且很大概率具有缺失值。因为CNN能够从缺失像素值的图像中提取特征,那么GNN也能类似的根据其他实体的表示推理出缺失实体,从而得到一个解释性和泛化性能更好的表示。
基于电子健康病历预测AD的两大挑战:
大多数EHR是稀疏的,缺少很多疾病记录或者实验评估数据
其次,在EHR中所有变量类型之间的关系图结构很难构建,因为它需要大量的精力和专业知识才能为超过40万个节点手动构建合理的图结构,并且现有的本体仅捕获孤立的变量组之间的一小部分关系 。
本文利用Graph Attention Network 解决上述两大挑战,并根据EHR预测未确诊为AD的病人未来12-24月的AD诊断结果。
相关工作
GCN信号处理
GAT:图注意力网络(GAT)在图中的每个节点及其相邻节点上及其自身上学习局部特征,而不是使用频谱过滤器。 在这种架构中,GAT可以为边缘分配不同的重要性,从而增加了模型的容量和可解释性,并通过注意力权重学习了图结构本身。
本文根据已有的研究,提出创建EHR图网络,把每一个EHR当作一个节点,最初在这些节点上强加一个全连接结构,然后通过GAT中的自注意力机制隐式地学习EHR之间的图结构。 在EHR数据上引入了一个2层GAT,该模型可根据从隐藏变量到原始EHR图中的节点的连接以及相应的节点表示所计算出的隐藏变量来预测结果节点的标签。 如果结果不止一个,那么这将使我们的结果既可以同时关注图中的其他节点,也可以相互关注。
方法论
1. Graph Attention Layer
对于每一层图注意层,都有一个由边和节点组成的杂技节图,对于每个节点都有自己的特征矩阵,本文想要结合图信息对于每个节点生成新的表示集。
为了得到设定的输出维度并且扩大模型容量,调整特征矩阵h的维度d->d’,将线性层W应用到每个输入特征,然后根据注意力机制计算节点i 和节点j之间的权重 aij,计算公式为:
$alpha$=0.15
多头注意力机制就像在一个CNN中有多个卷积核,能够联合获得不同位置不同表示子空间的信息,所以本文运用k-head注意力机制。对于第k head,节点i的输出图表示是通过其邻居的加权总和及其注意力权重上的特征得出的。
2.Predicting Outcomes with Two-layer Graph
Embedding Layer
给定用于诊断,过程,实验室结果(合并并转换为二进制变量)和人口统计信息的one-hot 编码矢量,将这些变量转换为高维嵌入矢量。将含有医学信息的embedding训练为每个节点的特征值。
Input Graph Layer
对于每个病人的EHR,x1 x2 x3…xN是它的正向特征,把这些特征看作图中的观察节点以及这些节点间全连接的边。通过embedding lookup表查找到这些节点的特征表示h1 h2 …hN。最终通过K head 注意力机制可以得到k个图表示特征。
Output Graph Layer
当输入节点为x1 x2 x3…xN时,输出层将会输出生成的y1 y2 …yN。同样的也会将输入的每个节点进行全连接,因为输出节点y并不包含在输入图中,所以将输出表示 表示为已有embedding的concatenation. 然后将图注意力层应用到所有节点x1 x2 …xN y1 y2 …yM的全连接图上,得到k个图表示输出。
为了得到最终的输出,我们将多头注意力结果进行求均值,并且应用线性转换将 输出结果转换为分类输出。
通过两层图注意力之后,可以得到最终的概率分布表示。
损失函数:
实验
预测病人在未来12-24个月的AD诊断结果,所以只有一个诊断结果的节点。
数据
训练:
训练过程中随机mask 10%的数据。为了减少epoch且从正向数据中学习到足够的信息,对于正向的数据进行上采样50次。同样在每次训练中针对小于50岁的负样本进行80%的下采样来加快训练速度。
超参设置:
Model Performance
评估方式: ROC curve and precision-recall curve
模型对比结果
可视化
利用t-SNE算法,对AD病人的embedding特征以及图特征进行可视化。
具有相似性的医学概念聚集在一起
讨论
CBOW还是演示图结构功能的基准。通过将CBOW与基于图的模型(图表示,GAT)进行比较,我们了解到图网络比单个嵌入层具有更强的特征提取能力。同样,通过图结构,通过可视化节点和连接它们的边的位置,可以解释神经网络,如4.4节所示。关于EHR数据的稀疏性,应该注意的是,更频繁的特征对应于更一般的条件,并且规格较高的特征比其他特征更稀疏。这意味着最积极的特征是不包含强信号的一般特征。图网络通过将注意力集中在更重要的特征上来提高模型的性能。