Principles of Model Pruning
The paper prunes unimportant weights from the network to reduce the number of parameters. The procedure has three main steps:
(1) Train the full network normally.
(2) Prune unimportant connections. Weights with larger magnitude are generally considered more important, so a threshold is set and weights whose magnitude falls below it are removed.
(3) Retrain the network, and store the remaining parameters as a sparse matrix.
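Step (2), magnitude-based thresholding, can be sketched in a few lines of NumPy (an illustrative sketch, not the repo's code; the 0.1 threshold is an arbitrary choice for the example):

```python
import numpy as np

def magnitude_prune(weights, thresh):
    """Zero out every weight whose magnitude falls below thresh."""
    mask = np.abs(weights) >= thresh   # True where the connection survives
    return weights * mask, mask

w = np.array([0.5, -0.03, 0.2, 0.01, -0.7])
pruned, mask = magnitude_prune(w, thresh=0.1)
print(pruned)        # small-magnitude entries become 0
print(mask.sum())    # number of surviving connections
```

The surviving-connection mask is kept around because retraining (step 3) must not revive pruned weights.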
Source code analysis:
Taking the MNIST CNN model as an example, clone the GitHub source:
git clone https://github.com/garion9013/impl-pruning-TF
Training:
python train.py -2 -3
During training, the pretrained model model_ckpt_dense is loaded.
The fully connected layers fc1 and fc2 are then pruned:
def apply_prune(weights):
    dict_nzidx = {}
    for target in papl.config.target_layer:
        wl = "w_" + target
        print(wl + " threshold:\t" + str(papl.config.th[wl]))
        # Get target layer's weights
        weight_obj = weights[wl]
        weight_arr = weight_obj.eval()
        # Apply pruning
        weight_arr, w_nzidx, w_nnz = papl.prune_dense(weight_arr, name=wl,
                                                      thresh=papl.config.th[wl])
        # Store pruned weights as tensorflow objects
        dict_nzidx[wl] = w_nzidx
        sess.run(weight_obj.assign(weight_arr))
    return dict_nzidx
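The helper papl.prune_dense is not shown in this excerpt; judging from how apply_prune consumes its three return values (pruned array, nonzero-index mask, nonzero count), it presumably behaves like the following NumPy sketch (an assumption about its interface, not the repo's actual implementation):

```python
import numpy as np

def prune_dense(weight_arr, name="layer", thresh=0.005):
    """Zero weights below thresh; also return the keep-mask and nonzero count."""
    under = np.abs(weight_arr) < thresh
    weight_arr[under] = 0
    nzidx = ~under                  # boolean mask of surviving weights
    nnz = int(nzidx.sum())
    print("%s: kept %d of %d weights" % (name, nnz, weight_arr.size))
    return weight_arr, nzidx, nnz

w = np.array([[0.2, 0.001], [-0.3, 0.004]])
pruned, nzidx, nnz = prune_dense(w, name="w_fc1", thresh=0.005)
```

The nzidx mask returned here is what apply_prune stores in dict_nzidx for later gradient masking.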
After pruning, the entries of the fully connected weight matrices whose magnitude is below the threshold are 0, and the pruned model is saved as model_ckpt_dense_pruned.
The pruned model is then retrained.
During retraining, the gradients of the fully connected layers are masked so that only positions where the weights survived pruning (magnitude above the threshold) receive updates:
def apply_prune_on_grads(grads_and_vars, dict_nzidx):
    # Mask gradients with pruned elements
    for key, nzidx in dict_nzidx.items():
        count = 0
        for grad, var in grads_and_vars:
            if var.name == key + ":0":
                nzidx_obj = tf.cast(tf.constant(nzidx), tf.float32)
                grads_and_vars[count] = (tf.multiply(nzidx_obj, grad), var)
            count += 1
    return grads_and_vars
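The effect of the gradient mask can be illustrated without TensorFlow; this NumPy sketch shows the same elementwise masking idea (toy values, not from the repo):

```python
import numpy as np

# nzidx marks surviving weights; pruned positions must receive zero gradient
grad  = np.array([0.5, -0.2, 0.1, 0.3])
nzidx = np.array([True, False, True, False])

masked_grad = grad * nzidx.astype(np.float32)
print(masked_grad)   # pruned positions stay at zero during retraining
```

Without this mask, the optimizer would immediately push the zeroed weights away from zero again, undoing the pruning.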
Retraining produces the model model_ckpt_dense_retrained.
The network is then pruned once more: the fully connected weights below the threshold are set to 0, giving weight1, which is stored as a sparse matrix, i.e., the nonzero values of weight1 together with their indices.
Pruning:
def prune_tf_sparse(weight_arr, name="None", thresh=0.005):
    assert isinstance(weight_arr, np.ndarray)
    under_threshold = abs(weight_arr) < thresh
    weight_arr[under_threshold] = 0
    values = weight_arr[weight_arr != 0]
    indices = np.transpose(np.nonzero(weight_arr))
    shape = list(weight_arr.shape)
    count = np.sum(under_threshold)
    print("Non-zero count (Sparse %s): %s" % (name, weight_arr.size - count))
    return [indices, values, shape]
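A quick round trip confirms that the [indices, values, shape] triple preserves every surviving weight (a standalone check with hypothetical helper names, not code from the repo):

```python
import numpy as np

def to_sparse(weight_arr, thresh=0.005):
    """COO-style encoding: indices of surviving weights, their values, and shape."""
    w = weight_arr.copy()
    w[np.abs(w) < thresh] = 0
    indices = np.transpose(np.nonzero(w))
    values = w[w != 0]
    return indices, values, list(w.shape)

def to_dense(indices, values, shape):
    """Rebuild the dense matrix from the sparse triple."""
    dense = np.zeros(shape)
    dense[tuple(indices.T)] = values
    return dense

w = np.array([[0.3, 0.001], [0.0, -0.2]])
idx, vals, shape = to_sparse(w)
restored = to_dense(idx, vals, shape)
print(restored)   # only the sub-threshold 0.001 entry is lost
```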
Building the sparse representation:
def gen_sparse_dict(dense_w):
    sparse_w = dense_w
    for target in papl.config.target_all_layer:
        target_arr = np.transpose(dense_w[target].eval())
        sparse_arr = papl.prune_tf_sparse(target_arr, name=target)
        sparse_w[target + "_idx"] = tf.Variable(tf.constant(sparse_arr[0], dtype=tf.int32),
                                                name=target + "_idx")
        sparse_w[target] = tf.Variable(tf.constant(sparse_arr[1], dtype=tf.float32),
                                       name=target)
        sparse_w[target + "_shape"] = tf.Variable(tf.constant(sparse_arr[2], dtype=tf.int32),
                                                  name=target + "_shape")
    return sparse_w
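Whether this sparse encoding actually saves space depends on how many weights survive: each nonzero costs its value plus one index entry per dimension. A rough cost model (an illustrative approximation assuming 4-byte values and 4-byte indices):

```python
def sparse_bytes(nnz, ndim, bytes_per_val=4, bytes_per_idx=4):
    """Approximate COO storage: values + one index per dimension per nonzero."""
    return nnz * (bytes_per_val + ndim * bytes_per_idx)

def dense_bytes(size, bytes_per_val=4):
    return size * bytes_per_val

# e.g. a 1000x500 fully connected layer keeping 10% of its weights
size, nnz = 1000 * 500, 50000
print(sparse_bytes(nnz, ndim=2), "vs", dense_bytes(size))
```

At 10% density the sparse form wins comfortably; past roughly one third density, the index overhead makes it larger than the dense matrix.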
The model is then saved as model_ckpt_sparse_retrained.
After 10 iterations, the model shrinks from its original 13 MB to 3.9 MB, a 70% reduction in size, with a test accuracy of 0.9708.
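The quoted 70% figure checks out against the reported sizes (a trivial sanity check; the sizes are assumed to be megabytes):

```python
# Reported sizes from the experiment above
dense_mb, sparse_mb = 13.0, 3.9
reduction = 1.0 - sparse_mb / dense_mb
print("size reduction: %.0f%%" % (100 * reduction))
```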
Test results:
Original model:
python deploy_test.py -d -m model_ckpt_dense
Compressed model:
python deploy_test_pruned.py -d -m model_ckpt_sparse_retrained