目录
关于TransE,博客上各种博文漫天飞,对于原理我就不做重复性劳动。可以自己去搜,了解原理个人觉得以下两篇足够了:
算法伪代码
合页损失函数
SGD中的向量更新
前辈们的博文中对原理已经说的比较清楚了,但是对于SGD(stochastic gradient descent)随机梯度下降中的向量更新,都几乎没有过多讲解,简单的说一下我的理解。
随机梯度下降我们首先可以简化成梯度下降,随机是体现在sample时的。根据吴恩达老师机器学习的课程梯度下降可知,每一次对θ的更新如下(α是学习率):
线性回归的损失函数如下:
损失函数可以推广到更一般的形式,在TransE中我们的损失函数如下:
对于一般的数学表达式,不管是导数和偏导数我们都比较熟悉。以导数为例
对于一般函数,f(x)=x2,则f’(x)=(x2)’=2x*(dx)=2x
对于TransE中的损失函数,首先进行简化,去掉[]+和前面的两个sigma求和,简化成如下形式。
f(h,l,t,h’,t’)=(d(h+l-t)+d(h’+l-t’))注意这里面的五个参数都是向量,那么对h求偏导,
我们看目标函数,把h视为向量时,d(h+l-t)=h+l-t。把h视为[x1,x2,x3,...,xn]这样的n个一维变量时,当距离函数即d函数采用L1范式,即曼哈顿距离,d函数可以进一步推导:
d(h+l-t)=fabs(h+l-t).sum()
假设h+l-t=[1,-2,3,-4],那么d=1+2+3+4=10
所以我们推出:
所以我们得到,最后的偏导数向量要等于类似[1,1,1,-1,1,-1,...,-1]这样的一个n维向量。特别说一下,更新的时候一定要乘上学习率,否则很有可能会不收敛,形成z字形震荡。
代码实现
https://github.com/haidfs/TransE
代码简要分析
case1:TrainTransESimple
先实现基本的功能,各项模型内参数与超参如截图所示,可见单进程单线程一次训练一个batch_size为10000的batch,速度非常慢,接近11s(这是在内存128G的Linux服务器上,个人PC会更慢),
case2:TrainTransEMpManager
11s一轮实在说不上快。。在不考虑物理外挂(gpu)的情况下,先考虑使用多进程,最开始不太理解多进程的使用方法,最初的思路是多进程共享变量,将TransE类的变量在多个子进程间传递,于是有了TrainTransEMpManager.py(不建议在个人PC上运行,会非常卡),在这里面将TransE类的实例通过Manager共享。这个速度相比于之前的for循环存在一定的提升,但是不如multiprocessing.Queue()带来的性能提升大。个人理解:如果类比于多线程,每次线程的切入切出总是需要记录上下文信息,大量的线程会造成线程颠簸,带来不必要的开销;Python的多进程应该也是类似,当进程间共享的类的对象存在很多属性,即占用很大的内存空间时,切入和切出同样会带来很大的开销。这样的多进程反而降低了性能,多进程,应该尽可能精简共享的内存大小。在Linux服务器上,同样的batch_size,manger多进程一轮的时间为4.7s,如下:
case3:TrainTransEMpQueue
再看多进程实现同样的参数配置,通过multiprocessing的Queue(),每次仅仅共享[batch_size,dim]维大小的向量,和同样的case1相比:速度还是提升了不少,每一轮仅耗时4.7s。运行到1000轮之后,每轮运行时间接近2s。
初步训练与测试结果:
可以发现初步结果与论文结果较接近,但是还有一定的差距,等待后续调参再训练。
| FB15k | ||||
|---|---|---|---|---|
| epochs:2000 | MeanRank | [email protected] | ||
| raw | filter | raw | filter | |
| head | 320.743 | 192.152 | 29.7 | 41.2 |
| tail | 236.984 | 153.431 | 36.1 | 46.2 |
| average | 278.863 | 172.792 | 32.9 | 43.7 |
| paper | 243 | 125 | 34.9 | 47.1 |
疑问:
进行测试时由于单个测试例需要利用整个测试集的所有测试例替换头尾实体,在代码里面写了TestTransEMpQueue和TestMainTF两个版本的测试代码,但是使用Queue()的多进程效果并不理想,几乎没有提升,单个测试例0.4s左右,接近5w个测试例约为5.5小时,速度实在太慢。。。但是不明确为什么,希望有明白的大神可以多多指教。
TestMainTF进行一次测试的耗时为420s左右,约7min。