【问题标题】:Convert torch t7 model to pytorch(将 torch t7 模型转换为 pytorch)
【发布时间】:2019-02-01 18:43:41
【问题描述】:

我有一个 torch t7 模型,我想将其转换为 pytorch 模型。我用了这个方法:

model = load_lua('xxx.t7', unknown_classes=True)

但是,我收到以下错误:

AttributeError: type object 'torch.cuda.FloatStorage' has no attribute 'from_buffer'

知道怎么解决吗?

【问题讨论】:

    标签: lua pytorch torch


    【解决方案1】:

    有一个非常有用的转换脚本,我已经用过很多次了。

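    使用方法:创建一个 convert_torch.py 文件并将下面的代码粘贴到其中,然后以 .t7 模型文件作为 -m 参数运行脚本。注意:脚本依赖 torch.legacy 和 torch.utils.serialization.load_lua,这两者在 PyTorch 1.0 中已被移除,因此需要在 PyTorch 0.4.1 或更早的版本下运行。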

    python convert_torch.py -m xxx.t7

    from __future__ import print_function
    
    import os
    import math
    import torch
    import argparse
    import numpy as np
    import torch.nn as nn
    import torch.optim as optim
    import torch.legacy.nn as lnn
    import torch.nn.functional as F
    
    from functools import reduce
    from torch.autograd import Variable
    from torch.utils.serialization import load_lua
    
    class LambdaBase(nn.Sequential):
        """Sequential container that additionally carries a user-supplied callable.

        The callable is stored as ``lambda_func``; subclasses decide how it is
        applied to whatever ``forward_prepare`` returns.
        """

        def __init__(self, fn, *args):
            # Everything after ``fn`` is handed to nn.Sequential unchanged.
            super(LambdaBase, self).__init__(*args)
            self.lambda_func = fn

        def forward_prepare(self, input):
            """Call every registered child on ``input`` and collect the results.

            Returns the list of child outputs, or ``input`` itself when that
            list is empty (no children registered).
            """
            results = [child(input) for child in self._modules.values()]
            return results or input
    
    class Lambda(LambdaBase):
        """Apply the stored callable once to the prepared input."""

        def forward(self, input):
            prepared = self.forward_prepare(input)
            return self.lambda_func(prepared)
    
    class LambdaMap(LambdaBase):
        """Apply the stored callable to each element of the prepared input."""

        def forward(self, input):
            # One result per prepared element, returned as a list.
            prepared = self.forward_prepare(input)
            return [self.lambda_func(item) for item in prepared]
    
    class LambdaReduce(LambdaBase):
        """Fold the prepared input into a single value with the stored callable."""

        def forward(self, input):
            # Left fold over the prepared elements; no initial value is supplied.
            prepared = self.forward_prepare(input)
            return reduce(self.lambda_func, prepared)
    
    
    def copy_param(m,n):
        """Copy learned values from source layer ``m`` into PyTorch module ``n``.

        ``weight`` and ``bias`` are copied when the source value is not None.
        ``running_mean`` / ``running_var`` are copied whenever the destination
        module has an attribute of that name.
        """
        if m.weight is not None:
            n.weight.data.copy_(m.weight)
        if m.bias is not None:
            n.bias.data.copy_(m.bias)
        for stat in ('running_mean', 'running_var'):
            if hasattr(n, stat):
                getattr(n, stat).copy_(getattr(m, stat))
    
    def add_submodule(seq, *args):
        """Append each module in ``args`` to ``seq``.

        Every child is registered under the decimal string of the number of
        children ``seq`` holds at the moment it is added.
        """
        for child in args:
            next_index = len(seq._modules)
            seq.add_module(str(next_index), child)
    
    def lua_recursive_model(module,seq):
        """Rebuild the children of a loaded Lua container as PyTorch modules.

        Walks ``module.modules`` in order, creates the PyTorch counterpart of
        every recognised layer, copies its parameters where it has any, and
        appends the result to ``seq``.  Containers (Sequential, ConcatTable,
        Concat) are converted recursively.  Unrecognised layers are reported
        with ``print`` and skipped.

        Args:
            module: Lua container exposing a ``modules`` iterable.
            seq: ``nn.Sequential``-like module that receives the converted
                children (mutated in place).
        """
        for m in module.modules:
            name = type(m).__name__
            real = m
            # Objects whose type name is 'TorchObject' carry the Lua type name
            # in ``_typename`` and the actual fields in ``_obj``.
            wrapped = name == 'TorchObject'
            if wrapped:
                name = m._typename.replace('cudnn.','')
                m = m._obj

            # Torch7 layers expose width/height as separate kW/kH, dW/dH, ...
            # fields; PyTorch expects (height, width) tuples, so every tuple
            # below is built height-first.
            if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
                if not hasattr(m,'groups') or m.groups is None: m.groups=1
                n = nn.Conv2d(m.nInputPlane,m.nOutputPlane,(m.kH,m.kW),(m.dH,m.dW),(m.padH,m.padW),1,m.groups,bias=(m.bias is not None))
                copy_param(m,n)
                add_submodule(seq,n)
            elif name == 'SpatialBatchNormalization':
                n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
                copy_param(m,n)
                add_submodule(seq,n)
            elif name == 'VolumetricBatchNormalization':
                n = nn.BatchNorm3d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
                copy_param(m, n)
                add_submodule(seq, n)
            elif name == 'ReLU':
                n = nn.ReLU()
                add_submodule(seq,n)
            elif name == 'Sigmoid':
                n = nn.Sigmoid()
                add_submodule(seq,n)
            elif name == 'SpatialMaxPooling':
                n = nn.MaxPool2d((m.kH,m.kW),(m.dH,m.dW),(m.padH,m.padW),ceil_mode=m.ceil_mode)
                add_submodule(seq,n)
            elif name == 'SpatialAveragePooling':
                n = nn.AvgPool2d((m.kH,m.kW),(m.dH,m.dW),(m.padH,m.padW),ceil_mode=m.ceil_mode)
                add_submodule(seq,n)
            elif name == 'SpatialUpSamplingNearest':
                n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
                add_submodule(seq,n)
            elif name == 'View':
                # NOTE(review): always flattens to (batch, -1); the sizes stored
                # on the Lua View module are not consulted.
                n = Lambda(lambda x: x.view(x.size(0),-1))
                add_submodule(seq,n)
            elif name == 'Reshape':
                n = Lambda(lambda x: x.view(x.size(0),-1))
                add_submodule(seq,n)
            elif name == 'Linear':
                # A 1-D input is promoted to a single-row batch before the
                # linear layer; any other input is passed through untouched.
                n1 = Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )
                n2 = nn.Linear(m.weight.size(1),m.weight.size(0),bias=(m.bias is not None))
                copy_param(m,n2)
                n = nn.Sequential(n1,n2)
                add_submodule(seq,n)
            elif name == 'Dropout':
                m.inplace = False
                n = nn.Dropout(m.p)
                add_submodule(seq,n)
            elif name == 'SoftMax':
                n = nn.Softmax()
                add_submodule(seq,n)
            elif name == 'Identity':
                n = Lambda(lambda x: x) # do nothing
                add_submodule(seq,n)
            elif name == 'SpatialFullConvolution':
                n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kH,m.kW),(m.dH,m.dW),(m.padH,m.padW),(m.adjH,m.adjW))
                copy_param(m,n)
                add_submodule(seq,n)
            elif name == 'VolumetricFullConvolution':
                n = nn.ConvTranspose3d(m.nInputPlane,m.nOutputPlane,(m.kT,m.kH,m.kW),(m.dT,m.dH,m.dW),(m.padT,m.padH,m.padW),(m.adjT,m.adjH,m.adjW),m.groups)
                copy_param(m,n)
                add_submodule(seq, n)
            elif name == 'SpatialReplicationPadding':
                n = nn.ReplicationPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
                add_submodule(seq,n)
            elif name == 'SpatialReflectionPadding':
                n = nn.ReflectionPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
                add_submodule(seq,n)
            elif name == 'Copy':
                n = Lambda(lambda x: x) # do nothing
                add_submodule(seq,n)
            elif name == 'Narrow':
                # Arguments are bound as a default so the lambda does not see
                # a later loop iteration's ``m``.
                n = Lambda(lambda x,a=(m.dimension,m.index,m.length): x.narrow(*a))
                add_submodule(seq,n)
            elif name == 'SpatialCrossMapLRN':
                lrn = lnn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k)
                n = Lambda(lambda x,lrn=lrn: Variable(lrn.forward(x.data)))
                add_submodule(seq,n)
            elif name == 'Sequential':
                n = nn.Sequential()
                lua_recursive_model(m,n)
                add_submodule(seq,n)
            elif name == 'ConcatTable': # output is list
                n = LambdaMap(lambda x: x)
                lua_recursive_model(m,n)
                add_submodule(seq,n)
            elif name == 'CAddTable': # input is list
                n = LambdaReduce(lambda x,y: x+y)
                add_submodule(seq,n)
            elif name == 'Concat':
                dim = m.dimension
                n = LambdaReduce(lambda x,y,dim=dim: torch.cat((x,y),dim))
                lua_recursive_model(m,n)
                add_submodule(seq,n)
            elif wrapped:
                # Unknown wrapped object: also report the original Lua type name
                # (including any 'cudnn.' prefix that was stripped from ``name``).
                print('Not Implement',name,real._typename)
            else:
                print('Not Implement',name)
    
    
    def lua_recursive_source(module):
        """Render the children of a loaded Lua container as PyTorch source lines.

        Walks ``module.modules`` in order and emits, for every recognised layer,
        the text of the constructor call that builds its PyTorch counterpart.
        Containers (Sequential, ConcatTable, Concat) are rendered recursively.
        An unrecognised layer becomes a ``# <name> Not Implement`` comment line.

        Args:
            module: Lua container exposing a ``modules`` iterable.

        Returns:
            list of str: one entry per source line, each prefixed with one tab
            per nesting level.  Lines carry no trailing separator.
        """
        s = []
        for m in module.modules:
            name = type(m).__name__
            # Objects whose type name is 'TorchObject' carry the Lua type name
            # in ``_typename`` and the actual fields in ``_obj``.
            if name == 'TorchObject':
                name = m._typename.replace('cudnn.','')
                m = m._obj

            # Torch7 layers expose width/height as separate kW/kH, dW/dH, ...
            # fields; PyTorch expects (height, width) tuples, so every tuple
            # below is rendered height-first.
            if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
                if not hasattr(m,'groups') or m.groups is None: m.groups=1
                s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane,
                    m.nOutputPlane,(m.kH,m.kW),(m.dH,m.dW),(m.padH,m.padW),1,m.groups,m.bias is not None)]
            elif name == 'SpatialBatchNormalization':
                s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
            elif name == 'VolumetricBatchNormalization':
                s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
            elif name == 'ReLU':
                s += ['nn.ReLU()']
            elif name == 'Sigmoid':
                s += ['nn.Sigmoid()']
            elif name == 'SpatialMaxPooling':
                s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kH,m.kW),(m.dH,m.dW),(m.padH,m.padW),m.ceil_mode)]
            elif name == 'SpatialAveragePooling':
                s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kH,m.kW),(m.dH,m.dW),(m.padH,m.padW),m.ceil_mode)]
            elif name == 'SpatialUpSamplingNearest':
                s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)]
            elif name == 'View':
                s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View']
            elif name == 'Reshape':
                s += ['Lambda(lambda x: x.view(x.size(0),-1)), # Reshape']
            elif name == 'Linear':
                s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
                s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1),m.weight.size(0),(m.bias is not None))
                s += ['nn.Sequential({},{}),#Linear'.format(s1,s2)]
            elif name == 'Dropout':
                s += ['nn.Dropout({})'.format(m.p)]
            elif name == 'SoftMax':
                s += ['nn.Softmax()']
            elif name == 'Identity':
                s += ['Lambda(lambda x: x), # Identity']
            elif name == 'SpatialFullConvolution':
                s += ['nn.ConvTranspose2d({},{},{},{},{},{})'.format(m.nInputPlane,
                    m.nOutputPlane,(m.kH,m.kW),(m.dH,m.dW),(m.padH,m.padW),(m.adjH,m.adjW))]
            elif name == 'VolumetricFullConvolution':
                s += ['nn.ConvTranspose3d({},{},{},{},{},{},{})'.format(m.nInputPlane,
                    m.nOutputPlane,(m.kT,m.kH,m.kW),(m.dT,m.dH,m.dW),(m.padT,m.padH,m.padW),(m.adjT,m.adjH,m.adjW),m.groups)]
            elif name == 'SpatialReplicationPadding':
                s += ['nn.ReplicationPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
            elif name == 'SpatialReflectionPadding':
                s += ['nn.ReflectionPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
            elif name == 'Copy':
                s += ['Lambda(lambda x: x), # Copy']
            elif name == 'Narrow':
                s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension,m.index,m.length))]
            elif name == 'SpatialCrossMapLRN':
                lrn = 'lnn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k))
                s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)]

            elif name == 'Sequential':
                s += ['nn.Sequential( # Sequential']
                s += lua_recursive_source(m)
                s += [')']
            elif name == 'ConcatTable':
                s += ['LambdaMap(lambda x: x, # ConcatTable']
                s += lua_recursive_source(m)
                s += [')']
            elif name == 'CAddTable':
                s += ['LambdaReduce(lambda x,y: x+y), # CAddTable']
            elif name == 'Concat':
                s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)]
                s += lua_recursive_source(m)
                s += [')']
            else:
                # Must be appended as ONE list element: ``s += '<str>'`` would
                # splice the string in character by character.  No trailing
                # ',\n' here -- the caller adds the line terminator itself.
                s += ['# ' + name + ' Not Implement']
        # One extra tab per nesting level; returned as a real list so callers
        # can index, measure or iterate it more than once.
        return ['\t{}'.format(x) for x in s]
    
    def simplify_source(s):
        """Strip default arguments from generated lines and join them into text.

        Each line has the rewrite rules below applied in order, is terminated
        with ``',\\n'``, and loses its first character (the outermost tab added
        by the source generator).  The pieces are then concatenated.

        Args:
            s: iterable of source lines (list, map object, generator, ...).

        Returns:
            str: the concatenated text; the empty string for empty input.
        """
        # Order matters: the longest default-argument tail of each layer type
        # must be tried before its shorter suffixes, and the bare '#Tag' marker
        # is removed last.
        rewrites = (
            (',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'),
            (',(0, 0),1,1,bias=True),#Conv2d', ')'),
            (',1,1,bias=True),#Conv2d', ')'),
            (',bias=True),#Conv2d', ')'),
            ('),#Conv2d', ')'),
            (',1e-05,0.1,True),#BatchNorm2d', ')'),
            ('),#BatchNorm2d', ')'),
            (',(0, 0),ceil_mode=False),#MaxPool2d', ')'),
            (',ceil_mode=False),#MaxPool2d', ')'),
            ('),#MaxPool2d', ')'),
            (',(0, 0),ceil_mode=False),#AvgPool2d', ')'),
            (',ceil_mode=False),#AvgPool2d', ')'),
            (',bias=True)),#Linear', ')), # Linear'),
            (')),#Linear', ')), # Linear'),
        )

        pieces = []
        for line in s:
            for old, new in rewrites:
                line = line.replace(old, new)
            # Terminate the line, then drop the leading tab.
            pieces.append('{},\n'.format(line)[1:])
        # join() handles the empty case, where reduce() without an initial
        # value would raise TypeError.
        return ''.join(pieces)
    
    def torch_to_pytorch(t7_filename,outputname=None):
        """Convert a Torch7 ``.t7`` model into a PyTorch source file and weights.

        Writes two files: ``<outputname>.py`` containing the helper classes and
        an assignment that rebuilds the network, and ``<outputname>.pth``
        containing the converted ``state_dict``.

        Args:
            t7_filename: path of the ``.t7`` file to load.
            outputname: output path prefix.  When None, the prefix is the
                model's directory joined with the variable name derived from
                its file name.
        """
        model = load_lua(t7_filename,unknown_classes=True)
        # Some checkpoints wrap the network in a table with a ``model`` field.
        if type(model).__name__=='hashable_uniq_dict': model=model.model
        model.gradInput = None
        # Wrap in a container so a bare (non-container) model is handled too.
        slist = lua_recursive_source(lnn.Sequential().add(model))
        s = simplify_source(slist)

        # Built line by line so the emitted text does not depend on how this
        # function itself happens to be indented.
        header = '\n'.join([
            '',
            'import torch',
            'import torch.nn as nn',
            'import torch.legacy.nn as lnn',
            'from functools import reduce',
            'from torch.autograd import Variable',
            'class LambdaBase(nn.Sequential):',
            '    def __init__(self, fn, *args):',
            '        super(LambdaBase, self).__init__(*args)',
            '        self.lambda_func = fn',
            '    def forward_prepare(self, input):',
            '        output = []',
            '        for module in self._modules.values():',
            '            output.append(module(input))',
            '        return output if output else input',
            'class Lambda(LambdaBase):',
            '    def forward(self, input):',
            '        return self.lambda_func(self.forward_prepare(input))',
            'class LambdaMap(LambdaBase):',
            '    def forward(self, input):',
            '        return list(map(self.lambda_func,self.forward_prepare(input)))',
            'class LambdaReduce(LambdaBase):',
            '    def forward(self, input):',
            '        return reduce(self.lambda_func,self.forward_prepare(input))',
        ]) + '\n'

        # Derive the variable name from the file name only: sanitising the
        # whole path would mangle './' into '_/' and leave path separators in
        # the generated identifier.
        directory, base = os.path.split(t7_filename)
        if base.endswith('.t7'): base = base[:-3]
        varname = base.replace('.','_').replace('-','_')
        # s[:-2] drops the ',\n' that terminates the last generated line.
        s = '{}\n\n{} = {}'.format(header,varname,s[:-2])

        if outputname is None: outputname = os.path.join(directory,varname)
        with open(outputname+'.py', "w") as pyfile:
            pyfile.write(s)

        n = nn.Sequential()
        lua_recursive_model(model,n)
        torch.save(n.state_dict(),outputname+'.pth')
    
    
    # Command-line entry point.  Guarded so that importing this module (for
    # example to reuse the Lambda* classes) does not parse sys.argv or start a
    # conversion.
    if __name__ == '__main__':
        parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch')
        parser.add_argument('--model','-m', type=str, required=True,
                            help='torch model file in t7 format')
        parser.add_argument('--output', '-o', type=str, default=None,
                            help='output file name prefix, xxx.py xxx.pth')
        args = parser.parse_args()

        torch_to_pytorch(args.model,args.output)
    

    【讨论】:

      猜你喜欢
      • 2019-08-04
      • 1970-01-01
      • 2022-10-13
      • 2023-01-31
      • 2018-10-05
      • 2020-08-25
      • 1970-01-01
      • 2017-06-11
      • 2018-11-25
      相关资源
      最近更新 更多