线性判别分析
核心思想
线性判别分析(Linear Discriment Analysis, LDA)是一种经典的分类算法。核心思想是将训练数据投影到这样的一条直线上:
- 在该直线上的所有投影点,属于同一类别的数据尽可能的近,简称类内近。
- 在该直线上的所有投影点,属于不同类的数据尽可能的远,简称类间远。
当我们对新数据进行预测时,只需同样的将该数据投影到求得的直线上,通过判断,投影点哪一个类近,就认为是哪一类。
总结起来就是六个字:类内近,类间远。
数学推导
我们以二分类为例:
我们假设输入样本为(X,Y),样本总数为N,有两个类别C1和C2。
X=⎝⎜⎜⎜⎛x11,x12,⋯x1px21,x22,⋯,x2p⋮xn1,xN2,⋯,xNp⎠⎟⎟⎟⎞N×p
Y=(y1,y2,⋯,yN)⊤,y∈{C1,C2}
则所求直线可表示为:
y=W⊤x
其中,W=(w1,w2,⋯,wN)⊤
为方便运算,我们假定∥W∥=1。

其中样本点∥xi∥在直线y=w⊤x的投影距为:
∥xi∥⋅cosθ=∥xi∥⋅∥w∥⋅cosθ=ω⊤xi
我们以投影距作为样本点xi在直线y=w⊤x的一维坐标。
类间距我们用样本的方差均值表示,类内距我们用样本方差表示。
样本均值:yˉ=N1i=1∑Nw⊤xi
样本方差:S=N1i=1∑N(w⊤xi−yˉ)(w⊤xi−yˉ)⊤
类间距为:(yˉ1−yˉ2)2
类内距为:S1+S2
优化函数为:J(w)=S1+S2(y1ˉ−yˉ2)2
w=argmaxwJ(w)=argmaxwS1+S2(yˉ1−yˉ2)
其中:
(yˉ1−yˉ2)2=(N11i=1∑N1w⊤xi−N21i=1∑N1w⊤xi)2=(w⊤Ni1i=1∑N1xi−ωii=1∑N2xi)2=w⊤(xˉc1−xˉc2)(xˉc1−xˉc2)⊤w
S1=N11i=1∑N1(ωjˉxi−yˉ1)(ω⊤xi−yˉ1)⊤=N11i=1∑N1(w⊤xi−)1N1i=1∑Nw⊤xi)⋅(ωxi−N11i=1∑Nw⊤xi=N11i=1∑N1w⊤(xi−N11i=1∑Nxi)⋅(xi−N11i=1∑Nxi)⊤⋅w=w⊤N11i=1∑N(xi−xˉ1)(xi−xˉ1)⊤⋅w=w⊤SC1w
同理S2=w⊤SC2w
故 :
J(w)=S1+S2(y1ˉ−yˉ2)2=w⊤(SC1+SC2)ww⊤(xˉc1−xˉc2)(xˉc1−xˉc2)⊤w
而其中(xˉc1−xˉc2)(xˉc1−xˉc2)⊤和(SC1+SC2)与w无关,为方便运算,我们令Sb=(xˉc1−xˉc2)(xˉc1−xˉc2)⊤, Sw=SC1+SC2。
则J(w)=w⊤Swww⊤Sbw
对w进行求导:
∂ω∂J(ω)=∂w∂ω⊤Sbω(ω⊤Swω)−1=2Sbw(w⊤Sww)−1+w⊤Sbw(−1)(w⊤Sww)−2⋅2Sωw=0
两边同乘以(w⊤Sww)2,得:
Sbw(w⊤Sww)−w⊤Sbw⋅Sww=0
Sbww⊤Sww=w⊤SbwSww
其中:Sb和Sw均是(p * p)维的,而w是(p * 1)维的,
则w⊤Sww 和 w⊤Sbw的结果均为1维常量,故可直接消去,不妨设为常数C。则
Sww=w⊤Sbww⊤SwwSbw
w=CSv−1Sbw
w∝Sw−1Sbw
w∝Sw−1(xˉc1−xˉc2)(xˉc1−xˉc2)⊤w
同理(xˉc1−xˉc2)⊤w 是一个常数。
故最终可获得:
w∝Sw−1(xˉc1−xˉc2)w∝(Sc1+Sc2)−1(xˉ1−xˉ2)
即我们计算得到了w的方向,从而也就知道了分类的超平面(垂直于该直线),便可进行分类任务。
算法实现
待更新。
参考文献
[1] 机器学习
[2] https://github.com/shuhuai007/Machine-Learning-Session
[3] 图解机器学习