【Question title】: How to flatten input in `nn.Sequential` in PyTorch
【Posted】: 2019-05-25 23:30:54
【Question】:

How do I flatten the input inside nn.Sequential? This is my current attempt:

Model = nn.Sequential(x.view(x.shape[0],-1),
                     nn.Linear(784,256),
                     nn.ReLU(),
                     nn.Linear(256,128),
                     nn.ReLU(),
                     nn.Linear(128,64),
                     nn.ReLU(),
                     nn.Linear(64,10),
                     nn.LogSoftmax(dim=1))

【Question discussion】:

    Tags: python neural-network artificial-intelligence pytorch


    【Solution 1】:

    You can define a new module class as shown below and place an instance of it (Flatten()) in the nn.Sequential alongside the other modules.

    import torch
    
    class Flatten(torch.nn.Module):
        def forward(self, x):
            batch_size = x.shape[0]
            return x.view(batch_size, -1)
    

    Reference: https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983
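
As a minimal sketch of how such a module slots into the container: the layer sizes follow the question, while the shortened layer stack and the batch of 32 single-channel 28x28 images are assumptions made for illustration.

```python
import torch
import torch.nn as nn


class Flatten(nn.Module):
    """Collapse every dimension except the batch dimension."""

    def forward(self, x):
        return x.view(x.shape[0], -1)


# Flatten() is an nn.Module instance, so nn.Sequential accepts it.
# (x.view(...) in the question is a tensor expression, not a module.)
model = nn.Sequential(
    Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.LogSoftmax(dim=1),
)

# Hypothetical input: a batch of 32 single-channel 28x28 images.
x = torch.randn(32, 1, 28, 28)
out = model(x)
print(out.shape)  # torch.Size([32, 10])
```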

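    Edit: Flatten is now part of torch. See https://pytorch.org/docs/stable/nn.html?highlight=flatten#torch.nn.Flatten

    【Discussion】:

    • Or just call out = x.view(batch_size, -1) directly in the forward method.
    • @DanielMöller Look at the question again; the OP wants to do this with nn.Sequential.
    • Got it, your answer is perfect.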
    【Solution 2】:

    torch.flatten is defined as:

    torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
    

    Its speed is comparable to view(), while reshape is slightly faster still.

    import torch
    import torch.nn as nn
    
    class Flatten(nn.Module):
        def forward(self, input):
            return input.view(input.size(0), -1)
    
    flatten = Flatten()
    
    t = torch.Tensor(3,2,2).random_(0, 10)
    print(t, t.shape)
    
    
    #https://pytorch.org/docs/master/torch.html#torch.flatten
    f = torch.flatten(t, start_dim=1, end_dim=-1)
    print(f, f.shape)
    
    
    #https://pytorch.org/docs/master/torch.html#torch.view
    f = t.view(t.size(0), -1)
    print(f, f.shape)
    
    
    #https://pytorch.org/docs/master/torch.html#torch.reshape
    f = t.reshape(t.size(0), -1)
    print(f, f.shape)
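
Speed aside, the three calls do not behave identically on every input. The sketch below uses a small hand-made tensor (an assumption, not data from the answer) to show that view() raises an error when the memory layout is not compatible, whereas reshape() and torch.flatten() fall back to copying.

```python
import torch

t = torch.arange(12).reshape(3, 2, 2)

# On a contiguous tensor all three calls return the same values.
a = torch.flatten(t, start_dim=1)
b = t.view(t.size(0), -1)
c = t.reshape(t.size(0), -1)
same = torch.equal(a, b) and torch.equal(b, c)
print(same)

# transpose returns a view with swapped strides, so it is non-contiguous.
nc = t.transpose(1, 2)
print(nc.is_contiguous())

# view cannot merge the last two dimensions without copying, so it raises.
try:
    nc.view(nc.size(0), -1)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)

# reshape (and torch.flatten) copy the data when a view is not possible.
r = nc.reshape(nc.size(0), -1)
print(r.shape)
```

If the input may be non-contiguous (for example after transpose or permute), reshape or torch.flatten is the safer choice.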
    

    Speed check:

    # flatten 3.49 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    # view 3.23 µs ± 228 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    # reshape 3.04 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    

    If we use the class above instead (timed with IPython's %timeit):

    flatten = Flatten()
    t = torch.Tensor(3,2,2).random_(0, 10)
    %timeit f=flatten(t)
    
    
    5.16 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
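
The measurements above rely on the IPython %timeit magic. A rough way to repeat them in a plain Python script is sketched below with the standard timeit module; the repeat counts are arbitrary assumptions, and absolute numbers will vary with hardware and PyTorch version.

```python
import timeit

import torch
import torch.nn as nn


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


flatten = Flatten()
t = torch.rand(3, 2, 2)

# The four variants compared in the answer.
candidates = {
    "torch.flatten": lambda: torch.flatten(t, start_dim=1),
    "view": lambda: t.view(t.size(0), -1),
    "reshape": lambda: t.reshape(t.size(0), -1),
    "Flatten module": lambda: flatten(t),
}

n = 10_000  # calls per repeat (arbitrary choice)
timings = {}
for name, fn in candidates.items():
    # Best of 5 repeats, converted to microseconds per call.
    best = min(timeit.repeat(fn, number=n, repeat=5))
    timings[name] = best / n * 1e6
    print(f"{name:>15}: {timings[name]:.2f} µs per call")
```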
    

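    This result shows that going through a module class is slower, which is why flattening the tensor inside forward is faster. I think this is the main reason nn.Flatten was not promoted.

    So my suggestion is to flatten inside forward for speed. Something like this: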

    out = inp.reshape(inp.size(0), -1)
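
For context, that line would sit inside the forward method of a custom module. A minimal sketch follows, assuming the layer sizes from the question; the class name Net and the input shape are hypothetical.

```python
import torch
import torch.nn as nn


class Net(nn.Module):  # hypothetical name, not from the original answer
    def __init__(self):
        super().__init__()
        # Same layers as in the question, without a flatten step.
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, inp):
        # Flatten everything except the batch dimension, then run the layers.
        out = inp.reshape(inp.size(0), -1)
        return self.layers(out)


net = Net()
y = net(torch.randn(8, 1, 28, 28))  # hypothetical batch of 8 images
print(y.shape)  # torch.Size([8, 10])
```

The trade-off is that the model is no longer a bare nn.Sequential, which is what the question asked for.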
    

    【Discussion】:

      【Solution 3】:

      You can modify your code as follows:

      Model = nn.Sequential(nn.Flatten(),
                           nn.Linear(784,256),
                           nn.ReLU(),
                           nn.Linear(256,128),
                           nn.ReLU(),
                           nn.Linear(128,64),
                           nn.ReLU(),
                           nn.Linear(64,10),
                           nn.LogSoftmax(dim=1))
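
A quick shape check illustrates what the arguments of nn.Flatten do. This is a sketch assuming a hypothetical batch of 32 MNIST-sized images: the default start_dim=1 keeps the batch dimension, while start_dim=0 collapses it as well, which nn.Linear(784, 256) cannot consume for batch sizes above 1.

```python
import torch
import torch.nn as nn

# Hypothetical batch of 32 single-channel 28x28 images.
x = torch.randn(32, 1, 28, 28)

# Defaults are start_dim=1, end_dim=-1: the batch dimension is preserved.
keep_batch = nn.Flatten()(x)

# start_dim=0 merges the batch dimension into the features.
collapse_all = nn.Flatten(0, -1)(x)

print(keep_batch.shape)    # torch.Size([32, 784])
print(collapse_all.shape)  # torch.Size([25088])
```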
      

      【Discussion】:
