在两年半之前作过梯度提升树(GBDT)原理小结,但是对GBDT的算法库XGBoost没有单独拿出来分析。虽然XGBoost是GBDT的一种高效实现,但是里面也加入了很多独有的思路和方法,值得单独讲一讲。因此讨论的时候,我会重点分析和GBDT不同的地方。

    本文主要参考了XGBoost的论文和陈天奇的PPT

    作为GBDT的高效实现,XGBoost是一个上限特别高的算法,因此在算法竞赛中比较受欢迎。简单来说,对比原算法GBDT,XGBoost主要从下面三个方面做了优化:

    一是算法本身的优化:在算法的弱学习器模型选择上,对比GBDT只支持决策树,还可以直接很多其他的弱学习器。在算法的损失函数上,除了本身的损失,还加上了正则化部分。在算法的优化方式上,GBDT的损失函数只对误差部分做负梯度(一阶泰勒)展开,而XGBoost损失函数对误差部分做二阶泰勒展开,更加准确。算法本身的优化是我们后面讨论的重点。

    二是算法运行效率的优化:对每个弱学习器,比如决策树建立的过程做并行选择,找到合适的子树分裂特征和特征值。在并行选择之前,先对所有的特征的值进行排序分组,方便前面说的并行选择。对分组的特征,选择合适的分组大小,使用CPU缓存进行读取加速。将各个分组保存到多个硬盘以提高IO速度。

    三是算法健壮性的优化:对于缺失值的特征,通过枚举所有缺失值在当前节点是进入左子树还是右子树来决定缺失值的处理方式。算法本身加入了L1和L2正则化项,可以防止过拟合,泛化能力更强。

    在上面三方面的优化中,第一部分算法本身的优化是重点也是难点。现在我们就来看看算法本身的优化内容。

2. XGBoost损失函数

    在看XGBoost本身的优化内容前,我们先回顾下GBDT的回归算法迭代的流程,详细算法流程见梯度提升树(GBDT)原理小结第三节,对于GBDT的第t颗决策树,主要是走下面4步:

    1)对样本i=1,2,...m,计算负梯度

rti=−[∂L(yi,f(xi)))∂f(xi)]f(x)=ft−1(x)

 

    2)利用Rtj,j=1,2,...,J。其中J为回归树t的叶子节点的个数。

    3) 对叶子区域j =1,2,..J,计算最佳拟合值

ctj=argmin⏟c∑xi∈RtjL(yi,ft−1(xi)+c)

 

    4) 更新强学习器

ft(x)=ft−1(x)+∑j=1JctjI(x∈Rtj)

 

    上面第一步是得到负梯度,或者是泰勒展开式的一阶导数。第二步是第一个优化求解,即基于残差拟合一颗CART回归树,得到J个叶子节点区域。第三步是第二个优化求解,在第二步优化求解的结果上,对每个节点区域再做一次线性搜索,得到每个叶子节点区域的最优取值。最终得到当前轮的强学习器。

    从上面可以看出,我们要求解这个问题,需要求解当前决策树最优的所有J个叶子节点区域和每个叶子节点区域的最优解ctj。GBDT采样的方法是分两步走,先求出最优的所有J个叶子节点区域,再求出每个叶子节点区域的最优解。

    对于XGBoost,它期望把第2步和第3步合并在一起做,即一次求解出决策树最优的所有J个叶子节点区域和每个叶子节点区域的最优解ctj。在讨论如何求解前,我们先看看XGBoost的损失函数的形式。

    在GBDT损失函数L(y,ft−1(x)+ht(x))的基础上,我们加入正则化项如下:

Ω(ht)=γJ+λ2∑j=1Jwtj2

 

    这里的w表示叶子区域的值,因此这里和论文保持一致。

    最终XGBoost的损失函数可以表达为:

Lt=∑i=1mL(yi,ft−1(xi)+ht(xi))+γJ+λ2∑j=1Jwtj2

 

     最终我们要极小化上面这个损失函数,得到第t个决策树最优的所有J个叶子节点区域和每个叶子节点区域的最优解wtj。XGBoost没有和GBDT一样去拟合泰勒展开式的一阶导数,而是期望直接基于损失函数的二阶泰勒展开式来求解。现在我们来看看这个损失函数的二阶泰勒展开式:

 

(1)Lt=∑i=1mL(yi,ft−1(xi)+ht(xi))+γJ+λ2∑j=1Jwtj2(2)≈∑i=1m(L(yi,ft−1(xi))+∂L(yi,ft−1(xi)∂ft−1(xi)ht(xi)+12∂2L(yi,ft−1(xi)∂ft−12(xi)ht2(xi))+γJ+λ2∑j=1Jwtj2

    

 

    为了方便,我们把第i个样本在第t个弱学习器的一阶和二阶导数分别记为

gti=∂L(yi,ft−1(xi)∂ft−1(xi),hti=∂2L(yi,ft−1(xi)∂ft−12(xi)

 

    则我们的损失函数现在可以表达为:

Lt≈∑i=1m(L(yi,ft−1(xi))+gtiht(xi)+12htiht2(xi))+γJ+λ2∑j=1Jwtj2

 

    损失函数里面wtj,因此我们的损失函数可以继续化简。

 

j

相关文章:

  • 2022-12-23
  • 2021-12-04
  • 2021-08-07
  • 2021-08-29
  • 2021-12-22
  • 2022-01-07
猜你喜欢
  • 2019-01-21
  • 2022-12-23
  • 2021-06-19
  • 2021-12-05
相关资源
相似解决方案