【问题标题】:Dimension error in implementation of a convolutional network卷积网络实现中的维​​度错误
【发布时间】:2020-06-06 01:13:52
【问题描述】:

我试图了解为什么我的分类器存在维度问题。这是我的代码:

class convnet(nn.Module):

    def __init__(self, num_classes=1000):
        super(convnet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride = 2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride = 2), #stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride = 2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(576, 128),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Linear(64,num_classes),
            nn.Softmax(),
       )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x,1) #x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def neuralnet(num_classes,**kwargs):
    model = convnet(**kwargs)
    return model

所以我的问题是:预期的 4D 输入(得到 2D 输入)

我很确定错误是由 flatten 命令引起的,但是我真的不明白为什么分类器具有完全密集的连接。如果有人知道我哪里出错了,那将非常有帮助!

谢谢

【问题讨论】:

  • 你的错误指向哪一行?
  • 它指向 x= self.classifier(x)
  • 你的意思是x = self.classifier(x)
  • 提示:总是发布完整的堆栈跟踪
  • 下次我会记住这一点!谢谢(:

标签: pytorch conv-neural-network


【解决方案1】:

扁平化后,分类器的输入有2维(大小:[batch_size, 576]),因此第一个线性层的输出也将有2维(大小: [batch_size, 128])。然后将该输出传递给nn.BatchNorm2d,这要求其输入具有 4 个维度(大小:[batch_size, channels, height, width])。

如果您想在 2D 输入上使用批量规范,您需要使用 nn.BatchNorm1d,它接受 3D 输入(大小:[batch_size, channels, length])或 2D输入(大小:[batch_size, length])。

self.classifier = nn.Sequential(
    nn.Linear(576, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(inplace=True),
    nn.Linear(128, 64),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(64),
    nn.Linear(64,num_classes),
    nn.Softmax(),
)

【讨论】:

  • 我明白了!感谢您的详细回复!
猜你喜欢
  • 2016-07-17
  • 1970-01-01
  • 2017-07-31
  • 2016-12-04
  • 1970-01-01
  • 2020-07-10
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多