Model Pruning Principles

The paper reduces the number of network parameters by pruning unimportant weights, in three main steps:

(Figure: Iterative Pruning pipeline)

(1) Train the full network normally.

(2) Prune the unimportant connections. Weights with larger magnitudes are generally considered more important, so a threshold is set and weights below it are pruned away.

(3) Retrain the network, and store its parameters as sparse matrices.
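The three steps above can be sketched in plain NumPy (an illustrative toy, not the paper's actual TensorFlow code; the threshold value here is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 4))      # (1) weights from a normally trained layer

thresh = 0.5                     # assumed threshold, for illustration only
mask = np.abs(w) >= thresh       # (2) keep only "important" (large) weights
w_pruned = w * mask

# (3) retraining would update only the surviving weights; the pruned
# result is then stored sparsely as (indices, values)
indices = np.transpose(np.nonzero(w_pruned))
values = w_pruned[w_pruned != 0]
```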


Source code analysis:

Taking the MNIST CNN model as an example, clone the source from GitHub:

git clone https://github.com/garion9013/impl-pruning-TF

Training:

python train.py -2 -3

Training loads an already-trained model: model_ckpt_dense.

The fully connected layers fc1 and fc2 are then pruned:

def apply_prune(weights):
    dict_nzidx = {}
    for target in papl.config.target_layer:
        wl = "w_" + target
        print(wl + " threshold:\t" + str(papl.config.th[wl]))
        # Get target layer's weights
        weight_obj = weights[wl]
        weight_arr = weight_obj.eval()
        # Apply pruning
        weight_arr, w_nzidx, w_nnz = papl.prune_dense(weight_arr, name=wl,
                                            thresh=papl.config.th[wl])
        # Store pruned weights as tensorflow objects
        dict_nzidx[wl] = w_nzidx
        sess.run(weight_obj.assign(weight_arr))
    return dict_nzidx

After pruning, the entries of each fully connected layer's weight matrix that fall below the threshold are set to 0, and the pruned model is saved as model_ckpt_dense_pruned.
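What papl.prune_dense does can be approximated with a small NumPy sketch (prune_dense_sketch is a hypothetical stand-in mirroring the call above, returning the pruned array, the nonzero-index mask, and the surviving-weight count):

```python
import numpy as np

def prune_dense_sketch(weight_arr, thresh=0.1):
    """Zero out weights below thresh; return pruned array, mask, count."""
    nzidx = np.abs(weight_arr) >= thresh   # positions that survive pruning
    pruned = weight_arr * nzidx            # small-magnitude weights -> 0
    return pruned, nzidx, int(nzidx.sum())

w = np.array([[0.5, -0.02], [0.08, -0.9]])
pruned, nzidx, nnz = prune_dense_sketch(w)
# pruned == [[0.5, 0.0], [0.0, -0.9]], nnz == 2
```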

Retraining the pruned model:

When computing gradients for the fully connected layers, only the gradients at positions where the weight survived pruning (i.e., exceeded the threshold) are kept:


def apply_prune_on_grads(grads_and_vars, dict_nzidx):
    # Mask gradients with pruned elements
    for key, nzidx in dict_nzidx.items():
        count = 0
        for grad, var in grads_and_vars:
            if var.name == key+":0":
                nzidx_obj = tf.cast(tf.constant(nzidx), tf.float32)
                grads_and_vars[count] = (tf.multiply(nzidx_obj, grad), var)
            count += 1
    return grads_and_vars
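This masking means pruned positions receive zero gradient, so they stay zero throughout retraining. A NumPy equivalent of one masked gradient update (illustrative values; the mask comes from apply_prune's dict_nzidx):

```python
import numpy as np

grad = np.array([[0.3, -0.1], [0.2, 0.4]])         # raw gradient
nzidx = np.array([[True, False], [False, True]])   # surviving-weight mask

masked_grad = grad * nzidx   # pruned positions get zero gradient
# masked_grad == [[0.3, 0.0], [0.0, 0.4]]
```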

Retraining produces the model: model_ckpt_dense_retrained.

Then prune again: set the entries of the fully connected layers' weight matrix weight that are below the threshold to 0, obtaining weight1, and store weight1 as a sparse matrix, i.e., record weight1's nonzero values together with their indices.

Pruning:


import numpy as np

def prune_tf_sparse(weight_arr, name="None", thresh=0.005):
    assert isinstance(weight_arr, np.ndarray)

    under_threshold = abs(weight_arr) < thresh
    weight_arr[under_threshold] = 0
    values = weight_arr[weight_arr != 0]
    indices = np.transpose(np.nonzero(weight_arr))
    shape = list(weight_arr.shape)

    count = np.sum(under_threshold)
    print("Non-zero count (Sparse %s): %s" % (name, weight_arr.size - count))
    return [indices, values, shape]
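A quick round-trip check of this (indices, values, shape) representation, re-implemented without the TF/papl dependencies (to_sparse and to_dense are illustrative helpers, not part of the repo):

```python
import numpy as np

def to_sparse(weight_arr, thresh=0.005):
    # Same logic as prune_tf_sparse above, minus the printing
    weight_arr = weight_arr.copy()
    weight_arr[np.abs(weight_arr) < thresh] = 0
    indices = np.transpose(np.nonzero(weight_arr))  # (nnz, ndim) index pairs
    values = weight_arr[weight_arr != 0]            # matching nonzero values
    return indices, values, list(weight_arr.shape)

def to_dense(indices, values, shape):
    # Rebuild the dense matrix from the sparse triplet
    dense = np.zeros(shape)
    dense[tuple(indices.T)] = values
    return dense

w = np.array([[0.2, 0.001], [0.0, -0.3]])
idx, vals, shp = to_sparse(w)
restored = to_dense(idx, vals, shp)
# restored == [[0.2, 0.0], [0.0, -0.3]]: the 0.001 entry was pruned
```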

Building the sparse weight dictionary:

def gen_sparse_dict(dense_w):
    sparse_w = dense_w
    for target in papl.config.target_all_layer:
        target_arr = np.transpose(dense_w[target].eval())
        sparse_arr = papl.prune_tf_sparse(target_arr, name=target)
        sparse_w[target+"_idx"]=tf.Variable(tf.constant(sparse_arr[0],dtype=tf.int32),
            name=target+"_idx")
        sparse_w[target]=tf.Variable(tf.constant(sparse_arr[1],dtype=tf.float32),
            name=target)
        sparse_w[target+"_shape"]=tf.Variable(tf.constant(sparse_arr[2],dtype=tf.int32),
            name=target+"_shape")
    return sparse_w

The model is then saved as model_ckpt_sparse_retrained.

After 10 iterations, the model shrinks from the original 13 MB to 3.9 MB, a compression of 70%, with a test accuracy of 0.9708.


Test results:

Original model:

python deploy_test.py -d -m model_ckpt_dense


Compressed model:

python deploy_test_pruned.py -d -m model_ckpt_sparse_retrained

