开源脉冲神经网络深度学习框架——惊蛰(SpikingJelly)
开源脉冲神经网络深度学习框架——惊蛰(SpikingJelly)
背景
近年来神经形态计算芯片发展迅速,大量高校企业团队跟进,这样的芯片运行SNN的能效比与速度都超越了传统的通用计算设备。相应的,神经形态感知芯片也发展迅速。目前已有各种模态的感知芯片,其中如北京大学黄铁军教授团队的Vidar相机,功能上仿照视网膜中央凹,能输出脉冲信号,高速情况下实现比传统相机更清晰的采样。脉冲网络研究领域顶会文章与Nature Science刊物文章也在逐年增长(如下图)。通过ANN转换SNN,SNN首次达到媲美ANN的性能。同时,随着梯度替代(surrogate gradient)法的提出,直接利用梯度下降法进行SNN训练成为可能。目前,利用SNN进行深度学习已经成为机器学习领域的研究热点,在视觉分类,目标检测,强化学习等领域取得了不错的成果。
*仅计算Research与Reviews文章
SpikingJelly的前身SpikingFlow,作为北京大学本科生《神经网络的计算基础》& 研究生《机器学习原理》课程教学实验平台,已经有将近一年的时间。SpikingFlow与SpikingJelly的关系犹如Minix之于Linux,前者作为教学用的示例,后者在前者的基础上改进,吸纳教学过程中的意见,新的工程实践和研究成果。
如上图所示,目前已有的机器学习框架和脉冲网络仿真框架都无法彼此兼容,传统的脉冲网络框架无法与当下成熟的深度学习技术相结合。因此,SpikingJelly也就应运而生。同时期,国外也诞生了类似的Norse框架,两者达成了良性互动,共同推进了该领域的发展。
图源: Deep learning incorporating biologically inspired neural dynamics and in-memory computing
脉冲神经网络介绍
脉冲神经网络,简称SNN,被誉为第三代人工神经网络,是由大脑这样一个脉冲信号处理系统启发而构建的。大脑具有高级智能,并且功耗较低(有一种说法是仅相当于一个25w的灯泡)。通过借鉴大脑中的脉冲结构,SNN能够在保持低功耗的前提下,达到与ANN相当的性能。
SNN的结构仿照了生物神经系统的组织结构,现在使用的深度SNN通常采用前向结构(如下图所示)。相比前向ANN,SNN的复杂性不仅体现在连接权重与拓扑的多样性,还体现在神经元内在的动力学方程上,能够同时处理时间域与空间域的信息。
基于SNN的诸多特性,其应用前景十分广阔,从简单的图像分类,动作识别,音频处理到复杂的音视频信号,强化学习,机器人控制任务,都是SNN施展自己本领的潜在舞台。
图源: TrueNorth: Accelerating From Zero to 64 Million Neurons in 10 Years
框架概况
惊蛰(SpikingJelly)是一个开源脉冲神经网络深度学习框架(框架主页:https://github.com/fangwei123456/spikingjelly (GitHub);https://git.openi.org.cn/OpenI/spikingjelly (OpenI))。SpikingJelly框架整体结构图如下,使用PyTorch作为自动微分后端,利用C++和CUDA扩展进行性能增强,同时支持CPU和GPU。框架中包含数据集,可视化,深度学习三大模块。目前社区主要由北大媒体学习组和鹏城实验室人工智能中心运营管理。
在SpikingJelly框架中,神经元的动态被描述为充电,放电,重置三个过程,与图中三种颜色分别对应。
框架支持梯度替代法与ANN转SNN法两种主流SNN的深度学习算法,也是在实际任务上目前性能最好的两种算法。之前介绍的神经元动态方程中,只有放电过程是不可微分的,为了解决这个问题,近年来领域内提出了替代梯度方法直接训练SNN,原理是梯度在反向传播时采用替代函数来近似放电过程中的阶跃函数。另一种深度学习方法,从ANN转换为相应的SNN,也受到了重点关注。这是由于ANN中的ReLU神经元非线性激活和SNN中IF神经元的发放率有着极强的相关性,可以基于类似的相关性将训练好的ANN转换为对应的SNN,省去了直接训练SNN的难题。
框架提供了Neuron,Layer,Functional,Encoding四个基本模块:
- Neuron提供了深度SNN中最常用的LIF和IF神经元;
- Layer提供了SNN中特有的网络层;
- Functional提供深度SNN所需的函数;
- Encoding提供常用的脉冲编码器。
框架还集成了多个神经形态数据集,因为大多数神经形态数据集需要专用软件读取,使用非常繁琐。因此SpikingJelly将常用神经形态数据集统一进行了封装,只需一行代码即可进行调用,其中事件数据被统一为(