VGG16

Running run/vgg16/vgg16_prune_demo.py:

 python ./run/vgg16/vgg16_prune_demo.py --config ./run/vgg16/prune.json

Error:

Traceback (most recent call last):
  File "./run/vgg16/vgg16_prune_demo.py", line 16, in <module>
    from logger import logger
  File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 67, in <module>
    logger = Logger()
  File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 42, in __init__
    json.dump(cfg, fp)
  File "/anaconda3/envs/deeplearning/lib/python3.7/json/__init__.py", line 179, in dump
    for chunk in iterable:
  File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 438, in _iterencode
    o = _default(o)
  File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type Config is not JSON serializable

The cause is that json cannot serialize the Config object, since the project uses a custom dotdict-based config rather than plain dicts.

Solution:

Change the json.dump() call in logger.py. Note that json.dump's cls argument expects a json.JSONEncoder subclass, so passing the dotdict class there does not work; a simple fix that does is to fall back to the object's attribute dict via default=:

            with open(self.cfgfile, 'w') as fp:
                json.dump(cfg, fp, default=lambda o: vars(o))

This explicitly tells the encoder how to serialize the custom config objects.
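A self-contained demonstration of the problem and the default= fix (this assumes dotdict is the usual dict subclass from utils.py and that Config is a plain object holding dotdicts, which is what the traceback suggests):

import json

class dotdict(dict):
    """dot.notation access to dict attributes, as in utils.py"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class Config:
    def __init__(self):
        self.base = dotdict(cuda=False, multi_gpus=False)

cfg = Config()
# json.dumps(cfg)  # raises: Object of type Config is not JSON serializable
print(json.dumps(cfg, default=lambda o: vars(o)))
# prints: {"base": {"cuda": false, "multi_gpus": false}}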

Another error:

AssertionError: Torch not compiled with CUDA enabled

Change "cuda": true to "cuda": false in prune.json (the demo reads it as cfg.base.cuda), since this Torch build was not compiled with CUDA.

Error:

FileNotFoundError: [Errno 2] No such file or directory: './logs/vgg16_cifar10/ckp.160.torch'

This happened because I did not run the steps in order; I should first have run:

CUDA_VISIBLE_DEVICES=0 python main.py --config ./run/vgg16/baseline.json

That command produces the ckp.160.torch checkpoint file.

 

So I used PyTorch's pretrained weights instead, changing this part of vgg16_prune_demo.py:

def get_pack():
    set_seeds()
    pack = recover_pack()

    model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    pack.net.module.load_state_dict(model_dict)

into:

def get_pack():
    set_seeds()
    pack = recover_pack()
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)

Then inspect the network structure at this point:

pack, GBNs = get_pack()
for name, child in pack.net.named_children():
    print(name)
    print(child)

print(GBNs)

 

A later step then fails:

Traceback (most recent call last):
  File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
    run()
  File "./run/vgg16/vgg16_prune_demo.py", line 112, in run
    pack, GBNs = get_pack()
  File "./run/vgg16/vgg16_prune_demo.py", line 29, in get_pack
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)
  File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for VGG:
    size mismatch for features.7.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).

This is because the model here is the vgg16_bn architecture; the plain vgg16 checkpoint has no BN layers, so its layer indices do not line up. Switch the weights to https://download.pytorch.org/models/vgg16_bn-6c64b313.pth

 

Yet another error:

Traceback (most recent call last):
  File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
    run()
  File "./run/vgg16/vgg16_prune_demo.py", line 114, in run
    cloned, _ = clone_model(pack.net)
  File "./run/vgg16/vgg16_prune_demo.py", line 54, in clone_model
    gbns = GatedBatchNorm2d.transform(model.module)
  File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
    type(self).__name__, name))
AttributeError: 'VGG' object has no attribute 'module'

Changing model.module to model fixes this, because I did not enable:

    if cfg.base.multi_gpus: # multi_gpus is set to False
        model = torch.nn.DataParallel(model)


How it works, judging from the code alone

After reading through all of the code, my understanding of how it works is as follows, taking the prune() function of vgg16_prune_demo.py as the example:

prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)

1) Prepare the Tick-Tock framework:

# first run Tock over the full dataset for cfg.gbn.tock_epoch epochs
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

This amounts to fine-tuning the original model for cfg.gbn.tock_epoch epochs.
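Conceptually, one Tock epoch is ordinary training plus an L1 sparsity penalty on the gates, so that unimportant gates are driven toward zero. A sketch of the idea, not the repo's exact loop (it assumes pack carries net, train_loader, criterion and optimizer, as in the code below):

def tock_epoch_sketch(pack, GBNs, sparse_lambda):
    # normal fine-tuning, plus an L1 term that pushes gate values g toward zero
    for x, y in pack.train_loader:
        pack.optimizer.zero_grad()
        loss = pack.criterion(pack.net(x), y)
        loss = loss + sparse_lambda * sum(gbn.g.abs().sum() for gbn in GBNs)
        loss.backward()
        pack.optimizer.step()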

 

2) Then loop over Tick steps:

def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
    LOGS = []
    flops_save_points = set([30, 20, 10])
    iter_idx = 0

    pack.tick_trainset = pack.train_loader
    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)
    # first run Tock over the full dataset for cfg.gbn.tock_epoch epochs
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
        num_to_prune = int(left_filter * cfg.gbn.p) # used to determine the threshold
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min) # one Tick, computing the scores
        flops, params = eval_prune(pack)
        info.update({ # record the state after this pruning step
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' % 
            (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))
        
        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0: # T=10: after every 10 Ticks, run Tock for tock_epoch=10 epochs
            print('Tocking:')
            prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

        flops_ratio = flops/BASE_FLOPS * 100 # what fraction of the original FLOPS remains
        for point in [i for i in list(flops_save_points)]:
            if flops_ratio <= point: # e.g. once flops_ratio drops below 30% but is still above 20%, save the current state and remove the 30 save point
                torch.save(pack.net.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point)) # no .module: DataParallel is not used
                flops_save_points.remove(point)

        if len(flops_save_points) == 0: # once empty, the Tick-Tock loop is done
            break

The Tick step computes the scores that decide which channels of the BN layers get pruned.
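The score itself comes from the paper's Taylor-expansion criterion: a filter's importance is estimated by |dL/dg * g|, accumulated over the Tick mini-batches. A rough sketch of the accumulation (hypothetical helper, assuming a backward pass has just populated g.grad):

def accumulate_scores_sketch(GBNs):
    # Taylor criterion from the paper: importance_i ~ |(dL/dg_i) * g_i|
    for gbn in GBNs:
        gbn.score += (gbn.g.grad.detach() * gbn.g.detach()).abs().view(-1)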

 

3) Before Tick-Tock starts, the network is prepared by swapping the BN layers for GBN layers:

def get_pack():
    set_seeds()
    pack = recover_pack()

    #model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), strict=False)
    #pack.net.module.load_state_dict(model_dict)

    
    GBNs = GatedBatchNorm2d.transform(pack.net) # after this, the BN layers become GBN layers, and each BN layer's weight is frozen (not trained)
    for gbn in GBNs:
        gbn.extract_from_bn()
        
#     for name, child in pack.net.named_children():
#         print(name)
#         print(child)
        
    pack.optimizer = optim.SGD(
        pack.net.parameters() ,
        lr=2e-3,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )

    return pack, GBNs

Inside GatedBatchNorm2d.transform(pack.net), the extract_from_bn() call adds the gate parameter g to each BN layer, rewrites the layer's bias and weight accordingly, and freezes the weight, so that during training only g gets optimized:

    def extract_from_bn(self):
        # freeze bn weight
        with torch.no_grad():
            self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10))
            self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
            self.bn.weight.set_(torch.ones_like(self.bn.weight))
            self.bn.weight.requires_grad = False

This matches the transformation described in the paper (Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks).
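The rewrite preserves the BN output exactly (as long as the clamp does not kick in): the original BN computes gamma * x_hat + beta; after extract_from_bn the BN computes x_hat + beta/gamma, and the GBN multiplies that by g = gamma, giving gamma * x_hat + beta again. A quick standalone numeric check of this identity (not repo code):

import torch

torch.manual_seed(0)
x_hat = torch.randn(8)          # normalized activations
gamma = torch.rand(8) + 0.5     # BN weight
beta = torch.randn(8)           # BN bias

bn_out = gamma * x_hat + beta   # original BN output

# after extract_from_bn: bias -> beta/gamma (clamped), g -> gamma, weight -> 1
new_bias = torch.clamp(beta / gamma, -10, 10)
gbn_out = (x_hat + new_bias) * gamma

print(torch.allclose(bn_out, gbn_out))  # True whenever |beta/gamma| <= 10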

 

On top of this, the Tock step simply trains the model on the entire training set in this state, i.e. with g added to each BN layer and the BN weight frozen.

 

4) Next comes the prune step:

info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min) # one Tick, computing the scores

which is really a Tick step followed by the actual pruning.

First, the Tick step:

    def tick(self, lr, test):
        ''' Do Prune '''
        self.freeze_conv()
        info = self.recover(lr, test)
        self.restore_conv()
        return info

It freezes the convolution layers' parameters, so during Tick training only the g parameters of the GBN layers and the parameters of the fully connected layers are updated.
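freeze_conv and restore_conv presumably just toggle requires_grad on the convolution parameters; a minimal sketch of that idea (hypothetical helper, not the repo's code):

import torch.nn as nn

def freeze_conv_sketch(net, freeze=True):
    # freeze (or restore) all convolution parameters; GBN gates and FC layers stay trainable
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            for p in m.parameters():
                p.requires_grad = not freeze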

Then comes the pruning itself:

Using the g trained during this Tick, a score is computed for each filter of every BN layer. Initially bn_mask (see the GatedBatchNorm2d class in prune/universal.py) is all ones, meaning every filter is kept, so self.score * self.bn_mask yields the scores of all filters. These scores are then sorted to derive a threshold value, a new self.mask is computed from that threshold, and the gate mask is updated via self.bn_mask = mask * g.bn_mask. Afterwards, a 0 in a GBN layer's bn_mask marks the corresponding filter as pruned, and a 1 marks it as kept.
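A simplified sketch of that thresholding logic (hypothetical shapes: score as a 1-D per-filter tensor, bn_mask stored as (1, C, 1, 1); the real version in prune/universal.py also enforces the minium_filter floor):

import torch

def update_masks_sketch(GBNs, num_to_prune):
    # scores of still-alive filters; already-pruned filters contribute 0
    scores = torch.cat([gbn.score * gbn.bn_mask.view(-1) for gbn in GBNs])
    # global threshold: the num_to_prune-th smallest score among alive filters
    threshold = scores[scores > 0].sort().values[num_to_prune - 1]
    for gbn in GBNs:
        mask = (gbn.score * gbn.bn_mask.view(-1) > threshold).float()
        gbn.bn_mask = mask.view_as(gbn.bn_mask) * gbn.bn_mask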

So pruning really takes effect through bn_mask, because the forward pass of GatedBatchNorm2d contains:

    def forward(self, x): 
        x = self.bn(x) * self.g

        self.area[0] = x.shape[-1] * x.shape[-2]

        if self.bn_mask is not None:
            return x * self.bn_mask
        return x

Hence during training, the forward pass through a GBN layer outputs x * self.bn_mask: every channel of x whose bn_mask entry is 0 becomes all zeros, which is equivalent to pruning that filter.
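For example, with four channels and bn_mask = [1, 0, 1, 1], channel 1 of the output is forced to zero, exactly as if that filter had been removed:

import torch

x = torch.randn(2, 4, 8, 8)                    # (batch, channels, H, W)
bn_mask = torch.tensor([1., 0., 1., 1.]).view(1, 4, 1, 1)
out = x * bn_mask
print(out[:, 1].abs().sum())                   # tensor(0.) -> channel 1 behaves as pruned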

 

5) Next, the channel counts of the convolution and fully connected layers are matched up with the pruned GBN layers:

    _ = Conv2dObserver.transform(pack.net) # no .module here: DataParallel is not used
    pack.net.classifier = FinalLinearObserver(pack.net.classifier)

Essentially this wraps them in Conv2dObserver and FinalLinearObserver respectively.

Conv2dObserver keeps two buffers, in_mask and out_mask, which accumulate absolute sums along the channel axis during the forward and backward passes of training; an entry that is still 0 at the end means that channel has been pruned:

    def _forward_hook(self, m, _in, _out):
        x = _in[0]
        self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

    def _backward_hook(self, grad):
        self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
        new_grad = torch.ones_like(grad)
        return new_grad

    def forward(self, x):
        output = self.conv(x)
        noise = torch.zeros_like(output).normal_()
        output = output + noise
        if self.training:
            output.register_hook(self._backward_hook)
        return output

FinalLinearObserver works on the same principle.
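Why this works: the upstream GBN layers zero out pruned channels, so the absolute sum over a pruned input channel is exactly 0 in every batch and its in_mask entry never grows, while every surviving channel accumulates something positive. A minimal standalone check of that logic:

import torch

x = torch.randn(2, 4, 8, 8)
x[:, 1] = 0.0   # pretend channel 1 was zeroed by an upstream GBN
in_mask = x.abs().sum(2, keepdim=True).sum(3, keepdim=True).sum(0, keepdim=True).view(-1)
print(in_mask)  # entry 1 is 0.0, every other entry is positive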

 

6) Then come the observe and melt_all steps:

    Meltable.observe(pack, 0.001)
    Meltable.melt_all(pack.net)

At first observe just looked like it adds a tiny value (1e-3) to the weights of the BN layers that were not converted to GBN, swaps ReLU for LeakyReLU, freezes the BN parameters, trains, and then restores everything afterwards, and I could not see what the point was.

Then it clicked: this pass runs training once purely so that the in_mask and out_mask buffers of Conv2dObserver and FinalLinearObserver get computed, and melt_all then uses them.

 

melt_all then rebuilds the network from the GBN, Conv2dObserver and FinalLinearObserver modules, using the collected in_mask and out_mask values together with each GBN's self.bn_mask: unneeded filters are dropped, and only the surviving filters' parameters are copied into the new network structure. This is done by the melt() function of each of these classes.
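For a convolution, the core of melt() is plain index selection: build a smaller Conv2d and copy over only the weight slices whose input and output channels survive. A simplified sketch (not the repo's melt(); it ignores groups and other corner cases):

import torch
import torch.nn as nn

def melt_conv_sketch(conv, in_mask, out_mask):
    # keep only the channels whose mask entry is non-zero
    in_idx = torch.nonzero(in_mask).view(-1)
    out_idx = torch.nonzero(out_mask).view(-1)
    new_conv = nn.Conv2d(len(in_idx), len(out_idx), conv.kernel_size,
                         stride=conv.stride, padding=conv.padding,
                         bias=conv.bias is not None)
    # conv weight has shape (out_channels, in_channels, kH, kW)
    new_conv.weight.data = conv.weight.data[out_idx][:, in_idx].clone()
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[out_idx].clone()
    return new_conv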

 

7) Finally, the new network is fine-tuned:

    _ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)

 

You have to save the fine-tuned model yourself.

1) Save only the model weights:

torch.save(pack.net.module.state_dict(), os.path.join(saving_path, '30_finetune_state.pth'))

The .module is there because of:

model = torch.nn.DataParallel(model)

If DataParallel was not used, drop the .module.

2) Save both the model weights and the network structure:

torch.save(pack.net.module, os.path.join(saving_path, '30_finetune.pth'))
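One thing to note when loading these back: after melt_all the pruned architecture no longer matches get_model(), so the bare state_dict from 1) only fits a network rebuilt with the same per-layer channel counts. The full-model save from 2) restores structure and weights in one call (sketch, with saving_path as above):

import os
import torch

net = torch.load(os.path.join(saving_path, '30_finetune.pth'), map_location='cpu')
net.eval()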

 

The full script:

import os
import sys

_r = os.getcwd().split('/')
_p = '/'.join(_r[:_r.index('gate-decorator-pruning')+1])
print('Change dir from %s to %s' % (os.getcwd(), _p))
os.chdir(_p)
sys.path.append(_p)

import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim

from config import cfg
from logger import logger
from main import set_seeds, recover_pack, adjust_learning_rate, _step_lr, _sgdr
from models import get_model
from utils import dotdict

from prune.universal import Meltable, GatedBatchNorm2d, Conv2dObserver, IterRecoverFramework, FinalLinearObserver
from prune.utils import analyse_model, finetune

def get_pack():
    set_seeds()
    pack = recover_pack()

    #model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), strict=False)
    #pack.net.module.load_state_dict(model_dict)

    
    GBNs = GatedBatchNorm2d.transform(pack.net) # after this, the BN layers become GBN layers, and each BN layer's weight is frozen (not trained)
    for gbn in GBNs:
        gbn.extract_from_bn()
        
#     for name, child in pack.net.named_children():
#         print(name)
#         print(child)
        
    pack.optimizer = optim.SGD(
        pack.net.parameters() ,
        lr=2e-3,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )

    return pack, GBNs
# get_pack()

def clone_model(net):
    model = get_model()
    gbns = GatedBatchNorm2d.transform(model)
    model.load_state_dict(net.state_dict())
    return model, gbns


def eval_prune(pack):
    cloned, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(cloned) # adapt the conv2d layers to the pruned BN (no .module: DataParallel is not used)
    cloned.classifier = FinalLinearObserver(cloned.classifier) # adapt the fully connected layer to the pruned BN
    cloned_pack = dotdict(pack.copy())
    cloned_pack.net = cloned
    Meltable.observe(cloned_pack, 0.001)
    Meltable.melt_all(cloned_pack.net) # rebuild all parameters according to the current gates
#     flops, params = analyse_model(cloned_pack.net.module, torch.randn(1, 3, 32, 32).cuda())
    flops, params = analyse_model(cloned_pack.net, torch.randn(1, 3, 32, 32)) # no .module: DataParallel is not used
    del cloned
    del cloned_pack
    
    return flops, params


def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
    LOGS = []
    flops_save_points = set([30, 20, 10])
    iter_idx = 0

    pack.tick_trainset = pack.train_loader
    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)
    # first run Tock over the full dataset for cfg.gbn.tock_epoch epochs
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
        num_to_prune = int(left_filter * cfg.gbn.p) # used to determine the threshold
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min) # one Tick, computing the scores
        flops, params = eval_prune(pack)
        info.update({ # record the state after this pruning step
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' % 
            (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))
        
        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0: # T=10: after every 10 Ticks, run Tock for tock_epoch=10 epochs
            print('Tocking:')
            prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

        flops_ratio = flops/BASE_FLOPS * 100 # what fraction of the original FLOPS remains
        for point in [i for i in list(flops_save_points)]:
            if flops_ratio <= point: # e.g. once flops_ratio drops below 30% but is still above 20%, save the current state and remove the 30 save point
                torch.save(pack.net.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point)) # no .module: DataParallel is not used
                flops_save_points.remove(point)

        if len(flops_save_points) == 0: # once empty, the Tick-Tock loop is done
            break


def run():
    pack, GBNs = get_pack()

    cloned, _ = clone_model(pack.net)
#     BASE_FLOPS, BASE_PARAM = analyse_model(cloned, torch.randn(1, 3, 32, 32).cuda()) # FLOPS and parameter count of the original pretrained model (GPU variant)
    BASE_FLOPS, BASE_PARAM = analyse_model(cloned, torch.randn(1, 3, 32, 32))
    print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
    print('%.3f M' % (BASE_PARAM / 1e6))
    del cloned

    prune(pack, GBNs, BASE_FLOPS, BASE_PARAM) # run the Tick-Tock procedure

    _ = Conv2dObserver.transform(pack.net) # no .module: DataParallel is not used
    pack.net.classifier = FinalLinearObserver(pack.net.classifier)
    Meltable.observe(pack, 0.001)
    Meltable.melt_all(pack.net)

    pack.optimizer = optim.SGD(
        pack.net.parameters(),
        lr=1,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )

    _ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)

run()