在之前,我们介绍了两种可解释性神经网络:
这里我们再介绍另一种网络结构:GAMI-Net,对应的文章为
- GAMI-Net: An Explainable Neural Network based on Generalized Additive Models with Structured Interactions
- Accurate Intelligible Models with Pairwise Interactions
前文回顾
前面两篇文章,做可解释性神经网络用的都是GAIM,但由于模型的复杂性依旧相对较高,中的的可解释性依旧比较弱,因此后面模型考虑使用GAM模型。
GAIM:
GAM:
Accurate Intelligible Models with Pairwise Interactions
这篇文章是将交互项比较早引入到GAM中的,我们先学习一下此文章。
首先先回顾一般的GAM模型:
添加交互项的GAM模型:
目标是最小化损失函数:
目前的问题在于对于高维的情形,添加的交互项过多,从而极大地影响算法的运行速度。若维度为,则交互项个数为。因此文章后面提出了一种贪婪前项分步筛选算法进行交互项的选取。针对交互项的选取,主要思路为先列出所有的交互项,而后一项一项的添加进入候选交互项集合,每一轮选出一个误差最小的,最后直到不再影响模型预测精度后,停止迭代。而选取的依据则是根据目前模型的残差,用一个交互项的函数去拟合现有的残差。
添加交互项的算法主体如下:
现在问题的关键在于如何选取这样一个交互项的函数,文章采用了几种方法进行实验:
第一种最简单的想法是用两个变量,记为一个pair找两个阈值,将二维区域分为四个子区域,如上图右侧节点所示。而后类似回归树,每个区域r的预测值都是用那个区域所有样本label的平均值进行预测。依靠下述的残差平方和RSS,我们可以选取每个pair的最优分割,记为。
由此,文章比较好的解决了交互项函数的选取,但算法复杂度依旧非常高,因为四个区域每个区域都需要计算一次平均值,而两个变量针对所有样本取得的值都需要遍历一次,复杂度非常高。由此文章提出了一种降低复杂度的计算方法。
如下图,四个区域我们分别记为,注意到四个区域,针对两个变量,我们可以通过遍历两个变量的所有样本可能取值,直接计算出的值,而后只需要计算出每个样本对应的值,就可以简单计算出对应的值,而不需要再重新遍历。
更进一步,我们可以类似地更加细分区域,如同树的方法,对分割的两个区域,再分别按照设置不同的阈值,对四个区域进行分割。
文章为了更进一步降低计算复杂度,针对样本量非常大的连续形变量,等距分为256个区域,转化为256个不同的值,再从中找寻最优分割,实验表明其对最终结果的精度几乎没有影响,并且能够极大降低算法复杂度。
最终文章的算法流程为:
- 第一步建立GAM模型,考虑所有的一维变量;
- 第二部逐步添加交互项,确立最终的模型。
真实数据
真实数据实验,针对回归数据集:
针对分类数据集:
从上述十个数据集的比较来看,本文提出的方法GA2M FAST在一系列基于GAM的方法中表现最好(GA2M Rand,GA2M Coef,GA2M Order分别为随机加交互项或者按照一个固定的条件添加交互项)。但实际上准确率与RF对比,发现RF总体来看表现得更好。
最后文章应用了上述提出的方法,对真实数据建立模型,并且最后论证选出来的交互项是非常有意义的交互项。
GAMI-Net: An Explainable Neural Network based on Generalized Additive Models with Structured Interactions
下面回到文章:《GAMI-Net: An Explainable Neural Network based on Generalized Additive Models with Structured Interactions》。
文章提出的网络结构:GAMI,全名为generalized additive models with structured interactions,其实也就是添加交互项的网络结构。本文从叙述中感觉和前面的方法与结构比较相似,下面是网络主要结构,主要考虑了多层子网络结构,Heredity以及添加了Marginal clarity约束。
Marginal clarity:
此约束为了保证模型的可识别性,以及为了防止交互项影响过大从而到时结果不够稳定。下面来叙述整个方法的流程。
此方法的流程和上篇文章的方法比较一致,也是先筛选Main Effects,而后再添加进交互项进一步进行loss的优化,最后还有一步来prune interactions,但是具体的方法可能由于是预印版,没有详细地说明,之后会看开源的代码继续进行学习。而对于具体的交互项选择,则也是使用类似回归树的分割方法。最后文章也对比了一些方法,在识别准确率上有进一步提升。