VGG16
Running run/vgg16/vgg16_prune_demo.py:
python ./run/vgg16/vgg16_prune_demo.py --config ./run/vgg16/prune.json
Error:
Traceback (most recent call last):
File "./run/vgg16/vgg16_prune_demo.py", line 16, in <module>
from logger import logger
File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 67, in <module>
logger = Logger()
File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 42, in __init__
json.dump(cfg, fp)
File "/anaconda3/envs/deeplearning/lib/python3.7/json/__init__.py", line 179, in dump
for chunk in iterable:
File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 438, in _iterencode
o = _default(o)
File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 179, in default
raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type Config is not JSON serializable
The cause is that json cannot serialize this object, because a custom dotdict is used here.
Solution:
Change the json.dump() call in logger.py to:

with open(self.cfgfile, 'w') as fp:
    json.dump(cfg, fp, cls=dotdict)

i.e. explicitly point the serializer at the custom dotdict.
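If passing cls=dotdict still raises an error on your setup (the json module's cls argument normally expects a json.JSONEncoder subclass), an alternative is to give json a fallback for objects it cannot encode. A minimal sketch of that alternative inside Logger.__init__:

with open(self.cfgfile, 'w') as fp:
    # fall back to a readable string for anything json cannot encode natively
    json.dump(cfg, fp, default=str)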
The next error:
AssertionError: Torch not compiled with CUDA enabled
Fix: change cuda: true to false in prune.json.
Then another error:
FileNotFoundError: [Errno 2] No such file or directory: './logs/vgg16_cifar10/ckp.160.torch'
This happened because I did not run the steps in order; I had not first run:
CUDA_VISIBLE_DEVICES=0 python main.py --config ./run/vgg16/baseline.json
That command produces the ckp.160.torch checkpoint file.
So instead I used the pretrained weights that PyTorch provides, changing this part of vgg16_prune_demo.py:

def get_pack():
    set_seeds()
    pack = recover_pack()
    model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    pack.net.module.load_state_dict(model_dict)
to:

def get_pack():
    set_seeds()
    pack = recover_pack()
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)
Then inspect the network structure at this point:
pack, GBNs = get_pack()
for name, child in pack.net.named_children():
    print(name)
    print(child)
print(GBNs)
Further on, the run fails again:
Traceback (most recent call last):
File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
run()
File "./run/vgg16/vgg16_prune_demo.py", line 112, in run
pack, GBNs = get_pack()
File "./run/vgg16/vgg16_prune_demo.py", line 29, in get_pack
pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)
File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for VGG:
size mismatch for features.7.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
This is because the vgg16_bn architecture should be used; otherwise there are no BN layers and the layer indices no longer line up. Change the weights URL to https://download.pytorch.org/models/vgg16_bn-6c64b313.pth.
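As a side note, strict=False only tolerates missing or unexpected keys; a size mismatch like the one above is still an error. Once the shapes line up you can check what was skipped, for example (same pack as in get_pack(); on recent PyTorch versions load_state_dict returns the skipped keys):

result = pack.net.load_state_dict(
    torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'),
    strict=False)
print(result.missing_keys)     # keys of this model that the checkpoint did not provide
print(result.unexpected_keys)  # checkpoint keys (e.g. the ImageNet classifier) that were ignored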
Another error:
Traceback (most recent call last):
File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
run()
File "./run/vgg16/vgg16_prune_demo.py", line 114, in run
cloned, _ = clone_model(pack.net)
File "./run/vgg16/vgg16_prune_demo.py", line 54, in clone_model
gbns = GatedBatchNorm2d.transform(model.module)
File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
type(self).__name__, name))
AttributeError: 'VGG' object has no attribute 'module'
Change model.module to model, because I did not wrap the model in DataParallel; with multi_gpus set to False, this branch is never taken:

if cfg.base.multi_gpus:  # multi_gpus is set to False here
    model = torch.nn.DataParallel(model)
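Since .module only exists when the net is wrapped in torch.nn.DataParallel, a small helper of my own (not from the repo) makes the code work in both cases:

def unwrap(model):
    # return the underlying model whether or not it is wrapped in DataParallel
    return model.module if isinstance(model, torch.nn.DataParallel) else model

gbns = GatedBatchNorm2d.transform(unwrap(model))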
How it works, judging purely from the code
Having read through all of the code, my understanding of how it works is as follows, taking the prune() function of vgg16_prune_demo.py as the example:
prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)
1) This sets up the Tick-Tock framework.

# first run cfg.gbn.tock_epoch epochs over all the data
prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

This is essentially fine-tuning the existing model for cfg.gbn.tock_epoch epochs.
2) Then Tick operations run in a loop:
def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
    LOGS = []
    flops_save_points = set([30, 20, 10])
    iter_idx = 0
    pack.tick_trainset = pack.train_loader

    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)
    # first run cfg.gbn.tock_epoch epochs over all the data
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
        num_to_prune = int(left_filter * cfg.gbn.p)  # used to determine the threshold
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min)  # one Tick plus score computation
        flops, params = eval_prune(pack)
        info.update({  # log the result of this pruning step
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' %
            (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))

        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0:  # T=10, i.e. after every 10 Ticks run a Tock of tock_epoch=10 epochs
            print('Tocking:')
            prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

        flops_ratio = flops/BASE_FLOPS * 100  # FLOPS remaining as a percentage of the baseline
        for point in [i for i in list(flops_save_points)]:
            if flops_ratio <= point:
                # e.g. once flops_ratio drops below 30% (but is still above 20%), save the current state and remove the 30 save point
                torch.save(pack.net.module.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point))
                flops_save_points.remove(point)

        if len(flops_save_points) == 0:  # when the set is empty, the Tick-Tock loop ends
            break
The Tick operation computes the scores that decide which channels of the BN layers get pruned.
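In the paper this score is a Taylor-expansion estimate of how much the loss would change if a channel's gate were zeroed, accumulated over the Tick batches. A minimal sketch of that idea for a single GBN layer (net, criterion, tick_loader and gbn are placeholder names; the repo's actual bookkeeping may differ in detail):

score = torch.zeros_like(gbn.g).view(-1)
for x, y in tick_loader:                       # Tick only uses a small subset of the training data
    loss = criterion(net(x), y)
    grad_g, = torch.autograd.grad(loss, gbn.g)
    score += (gbn.g * grad_g).abs().view(-1)   # score_i accumulates |g_i * dL/dg_i|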
3) Before Tick-Tock starts, the network is prepared by replacing every BN layer with a GBN layer:
def get_pack():
    set_seeds()
    pack = recover_pack()

    # model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), strict=False)
    # pack.net.module.load_state_dict(model_dict)

    GBNs = GatedBatchNorm2d.transform(pack.net)  # after this every BN layer is a GBN layer, and the BN weight is frozen (not trained)
    for gbn in GBNs:
        gbn.extract_from_bn()

    # for name, child in pack.net.named_children():
    #     print(name)
    #     print(child)

    pack.optimizer = optim.SGD(
        pack.net.parameters(),
        lr=2e-3,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )
    return pack, GBNs
The extract_from_bn() method of the GBN layers produced by GatedBatchNorm2d.transform(pack.net) adds the g parameter to the BN layer, rewrites its bias and weight, and freezes the weight, so that during training only g gets optimized:
def extract_from_bn(self):
    # freeze bn weight
    with torch.no_grad():
        self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10))
        self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
        self.bn.weight.set_(torch.ones_like(self.bn.weight))
        self.bn.weight.requires_grad = False
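The point of this rewrite is that the layer's output is unchanged: plain BN gives γ·x̂ + β, and after extract_from_bn the GBN computes (1·x̂ + β/γ) · (g·γ) with the initial g = 1, which is the same function (up to the clamp on β/γ). A quick numerical check of that identity with toy numbers (not repo code):

x_hat = torch.randn(8)                       # normalized activations
gamma, beta = torch.tensor(1.7), torch.tensor(-0.3)

before = gamma * x_hat + beta                # original BN: weight = gamma, bias = beta
g = 1.0 * gamma                              # g absorbs gamma
after = (1.0 * x_hat + beta / gamma) * g     # weight = 1, bias = beta/gamma, scaled by g

print(torch.allclose(before, after))         # True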
As described in the paper, the Tock operation then trains the model on the whole training set, on top of this setup where g has been added to the BN layers and the BN weights are frozen.
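Judging from the sparse_lambda argument of IterRecoverFramework and the description in the paper, Tock additionally puts an L1 sparsity penalty on the gates so that unimportant channels are pushed toward zero. Schematically (a sketch, not the repo's actual training loop):

def tock_loss(outputs, targets, gbns, sparse_lambda):
    # task loss plus an L1 penalty on all gate values
    ce = torch.nn.functional.cross_entropy(outputs, targets)
    l1 = sum(gbn.g.abs().sum() for gbn in gbns)
    return ce + sparse_lambda * l1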
4) Then the prune step runs:

info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min)  # one Tick plus score computation

This is really a Tick operation followed by the actual pruning.
The Tick operation is:
def tick(self, lr, test):
    ''' Do Prune '''
    self.freeze_conv()
    info = self.recover(lr, test)
    self.restore_conv()
    return info
It freezes the parameters of the convolution layers, so during Tick training only the g parameters of the GBN layers and the fully connected layer are updated.
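One plausible reading of freeze_conv()/restore_conv() is simply toggling requires_grad on the convolution weights (the repo may instead snapshot and restore the weights themselves; this sketch only illustrates the effect):

def freeze_conv(net):
    # disable gradients for every conv layer during Tick
    for m in net.modules():
        if isinstance(m, torch.nn.Conv2d):
            for p in m.parameters():
                p.requires_grad = False

def restore_conv(net):
    # re-enable gradients after Tick
    for m in net.modules():
        if isinstance(m, torch.nn.Conv2d):
            for p in m.parameters():
                p.requires_grad = True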
Next comes the actual pruning.
Using the g values trained during this Tick, a score is computed for each filter of every BN layer. Initially bn_mask (see the GatedBatchNorm2d class in prune/universal.py) is all ones, meaning every filter is kept, so self.score * self.bn_mask yields the scores of all surviving filters. Those scores are sorted to derive a threshold value, a new self.mask is computed from that threshold, and the stored mask is updated via self.bn_mask = mask * g.bn_mask. From then on, a 0 in a GBN layer's bn_mask means the corresponding filter has been pruned, and a 1 means it is kept.
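Conceptually, the threshold and mask update look like this (a sketch of the logic described above, not the repo's exact code):

def update_masks(gbns, num_to_prune):
    # collect scores of all filters, ignoring those already pruned (their masked score is 0)
    scores = torch.cat([(gbn.score * gbn.bn_mask).view(-1) for gbn in gbns])
    alive = scores[scores > 0]
    threshold = alive.sort()[0][num_to_prune]          # the num_to_prune-th smallest surviving score
    for gbn in gbns:
        mask = (gbn.score > threshold).float().view_as(gbn.bn_mask)
        gbn.bn_mask = mask * gbn.bn_mask               # 0 = filter pruned, 1 = filter kept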
So the pruning is really carried out through bn_mask, because the forward pass of the GatedBatchNorm2d class contains:
def forward(self, x):
    x = self.bn(x) * self.g
    self.area[0] = x.shape[-1] * x.shape[-2]
    if self.bn_mask is not None:
        return x * self.bn_mask
    return x
Hence during training, the output of a GBN layer is x * self.bn_mask: the channels of x whose bn_mask entry is 0 become all zeros, which is equivalent to having cut that filter.
5) The next step is to bring the channel counts of the convolution and fully connected layers in line with the pruning result in the GBN layers:
_ = Conv2dObserver.transform(pack.net.module)
pack.net.module.classifier = FinalLinearObserver(pack.net.module.classifier)
The main point is to wrap them in Conv2dObserver and FinalLinearObserver respectively.
Conv2dObserver holds two vectors, in_mask and out_mask, which accumulate per-channel sums over the activations in the forward pass and over the gradients in the backward pass; a channel whose accumulated sum stays 0 has effectively been pruned:
def _forward_hook(self, m, _in, _out):
    x = _in[0]
    self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

def _backward_hook(self, grad):
    self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
    new_grad = torch.ones_like(grad)
    return new_grad

def forward(self, x):
    output = self.conv(x)
    noise = torch.zeros_like(output).normal_()
    output = output + noise
    if self.training:
        output.register_hook(self._backward_hook)
    return output
FinalLinearObserver works on the same idea.
6) Then come the observe and melt_all steps:

Meltable.observe(pack, 0.001)
Meltable.melt_all(pack.net)
observe adds a tiny value (1e-3) to the weights of the BN layers that were not converted to GBN, replaces ReLU with LeakyReLU, freezes the BN parameters, trains for a bit, and then restores everything afterwards. At first I couldn't see what this was for, but it is really just a pass over the data to populate the in_mask and out_mask of the Conv2dObserver and FinalLinearObserver wrappers, which melt_all then uses.
melt_all then rebuilds the network: based on the in_mask and out_mask gathered by the observers and the bn_mask of each GBN layer, it drops the pruned filters and copies only the surviving filters' parameters into a new, smaller network structure. This is done by calling the melt() method of each of these classes.
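For a convolution layer, melting boils down to building a smaller nn.Conv2d and copying over only the kept channels. Roughly (a sketch of the idea, not the repo's melt()):

def melt_conv(old, in_mask, out_mask):
    # in_mask / out_mask are 0/1 vectors over input / output channels; keep the 1s
    in_idx = in_mask.nonzero().view(-1)
    out_idx = out_mask.nonzero().view(-1)
    new = torch.nn.Conv2d(len(in_idx), len(out_idx), old.kernel_size,
                          stride=old.stride, padding=old.padding,
                          bias=old.bias is not None)
    new.weight.data = old.weight.data[out_idx][:, in_idx].clone()
    if old.bias is not None:
        new.bias.data = old.bias.data[out_idx].clone()
    return new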
7) Finally, fine-tune the new, smaller network:
_ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)
The fine-tuned model has to be saved manually.
1) Save only the weights (the state_dict):
torch.save(pack.net.module.state_dict(), os.path.join(saving_path, '30_finetune_state.pth'))
The .module is needed because of:
model = torch.nn.DataParallel(model)
If DataParallel was not used, drop the .module.
2) Save the weights together with the network structure:
torch.save(pack.net.module, os.path.join(saving_path, '30_finetune.pth'))
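Because pruning changes the layer sizes, reloading from a state_dict alone would require rebuilding exactly the same pruned structure first, so option 2 is the more convenient one to load back later. A sketch of reloading it (the repo's model classes must still be importable):

model = torch.load(os.path.join(saving_path, '30_finetune.pth'), map_location='cpu')
model.eval()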
The complete script is:
import os
import sys

_r = os.getcwd().split('/')
_p = '/'.join(_r[:_r.index('gate-decorator-pruning')+1])
print('Change dir from %s to %s' % (os.getcwd(), _p))
os.chdir(_p)
sys.path.append(_p)

import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim

from config import cfg
from logger import logger
from main import set_seeds, recover_pack, adjust_learning_rate, _step_lr, _sgdr
from models import get_model
from utils import dotdict

from prune.universal import Meltable, GatedBatchNorm2d, Conv2dObserver, IterRecoverFramework, FinalLinearObserver
from prune.utils import analyse_model, finetune


def get_pack():
    set_seeds()
    pack = recover_pack()

    # model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), strict=False)
    # pack.net.module.load_state_dict(model_dict)

    GBNs = GatedBatchNorm2d.transform(pack.net)  # after this every BN layer is a GBN layer, and the BN weight is frozen (not trained)
    for gbn in GBNs:
        gbn.extract_from_bn()

    # for name, child in pack.net.named_children():
    #     print(name)
    #     print(child)

    pack.optimizer = optim.SGD(
        pack.net.parameters(),
        lr=2e-3,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )
    return pack, GBNs

# get_pack()


def clone_model(net):
    model = get_model()
    gbns = GatedBatchNorm2d.transform(model)
    model.load_state_dict(net.state_dict())
    return model, gbns


def eval_prune(pack):
    cloned, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(cloned.module)  # adjust the conv layers to the pruned BN layers
    cloned.module.classifier = FinalLinearObserver(cloned.module.classifier)  # adjust the fully connected layer to the pruned BN layers
    cloned_pack = dotdict(pack.copy())
    cloned_pack.net = cloned
    Meltable.observe(cloned_pack, 0.001)
    Meltable.melt_all(cloned_pack.net)  # rebuild all parameters according to the current g
    # flops, params = analyse_model(cloned_pack.net.module, torch.randn(1, 3, 32, 32).cuda())
    flops, params = analyse_model(cloned_pack.net.module, torch.randn(1, 3, 32, 32))
    del cloned
    del cloned_pack

    return flops, params


def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
    LOGS = []
    flops_save_points = set([30, 20, 10])
    iter_idx = 0
    pack.tick_trainset = pack.train_loader

    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)
    # first run cfg.gbn.tock_epoch epochs over all the data
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
        num_to_prune = int(left_filter * cfg.gbn.p)  # used to determine the threshold
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min)  # one Tick plus score computation
        flops, params = eval_prune(pack)
        info.update({  # log the result of this pruning step
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' %
            (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))

        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0:  # T=10, i.e. after every 10 Ticks run a Tock of tock_epoch=10 epochs
            print('Tocking:')
            prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

        flops_ratio = flops/BASE_FLOPS * 100  # FLOPS remaining as a percentage of the baseline
        for point in [i for i in list(flops_save_points)]:
            if flops_ratio <= point:
                # e.g. once flops_ratio drops below 30% (but is still above 20%), save the current state and remove the 30 save point
                torch.save(pack.net.module.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point))
                flops_save_points.remove(point)

        if len(flops_save_points) == 0:  # when the set is empty, the Tick-Tock loop ends
            break


def run():
    pack, GBNs = get_pack()
    cloned, _ = clone_model(pack.net)
    # BASE_FLOPS, BASE_PARAM = analyse_model(cloned, torch.randn(1, 3, 32, 32).cuda())  # FLOPS and parameter count of the pretrained model
    BASE_FLOPS, BASE_PARAM = analyse_model(cloned, torch.randn(1, 3, 32, 32))
    print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
    print('%.3f M' % (BASE_PARAM / 1e6))
    del cloned

    prune(pack, GBNs, BASE_FLOPS, BASE_PARAM)  # the Tick-Tock loop

    _ = Conv2dObserver.transform(pack.net.module)
    pack.net.module.classifier = FinalLinearObserver(pack.net.module.classifier)
    Meltable.observe(pack, 0.001)
    Meltable.melt_all(pack.net)

    pack.optimizer = optim.SGD(
        pack.net.parameters(),
        lr=1,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )

    _ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)


run()