【问题标题】:Pytorch CNN: Expected input to have 1 channel but got 60000 channels insteadPytorch CNN:预期输入有 1 个通道,但有 60000 个通道
【发布时间】:2021-12-26 17:15:48
【问题描述】:

在为 Fashion MNIST 数据集实现 NN 时,我收到以下错误:

RuntimeError: Given groups=1, weight of size [6, 1, 5, 5], expected input[1, 60000, 28, 28] to have 1 channels, but got 60000 channels instead

我推断 60000 是我的整个数据集的长度,但不知道为什么算法会给出这个错误。有人可以帮我解决这个问题吗?

我的数据集:

(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
train_data = []
test_data = []
train_data.append([X_train, y_train])
test_data.append([X_test, y_test])

trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=100)
testloader = torch.utils.data.DataLoader(test_data, shuffle=True, batch_size=100)

我按以下顺序收到错误(根据堆栈跟踪):

      8     #making predictions
----> 9     y_pred = model(images)

     32     #first hidden layer
---> 33     x = self.conv1(x)

更新 1

添加行:

images = images.transpose(0, 1)

按照 Ivan 的指示转置图像,但现在出现错误:

RuntimeError: expected scalar type Byte but found Float

【问题讨论】:

    标签: python neural-network pytorch conv-neural-network mnist


    【解决方案1】:

    你的输入是(1, 60000, 28, 28),而它应该是(60000, 1, 28, 28)。您可以通过调换前两个轴来解决此问题:

    >>> x.transpose(0, 1)
    

    【讨论】:

    • 感谢@Ivan,我在调用模型函数之前添加了语句,但出现了新错误。请参考我的更新。
    • 这应该被选为正确答案。您收到的新错误与标题无关。
    • @Ashar 您必须将输入转换为浮点数:x.float()
    【解决方案2】:

    从外观上看,您使用了 TensorFlow 的数据集。我用过torchvison的数据集,效果很好

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torchvision.datasets import FashionMNIST
    from torchvision import transforms
    
    
    class Network(nn.Module):
      def __init__(self):
        super(Network,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6,out_channels=12,kernel_size=5)
        self.fc1 = nn.Linear(in_features=12*4*4,out_features=120)
        self.fc2 = nn.Linear(in_features=120,out_features=60)
        self.fc3 = nn.Linear(in_features=60,out_features=40)
        self.out = nn.Linear(in_features=40,out_features=10)
      def forward(self,x):
        #input layer
        x = x
        #first hidden layer
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x,kernel_size=2,stride=2)
        #second hidden layer
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x,kernel_size=2,stride=2)
        #third hidden layer
        x = x.reshape(-1,12*4*4)
        x = self.fc1(x)
        x = F.relu(x)
        #fourth hidden layer
        x = self.fc2(x)
        x = F.relu(x)
        #fifth hidden layer
        x = self.fc3(x)
        x = F.relu(x)
        #output layer
        x = self.out(x)
        return x
    
    
    batch_size = 1000
    train_dataset = FashionMNIST(
        '../data', train=True, download=True, 
        transform=transforms.ToTensor())
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    test_dataset = FashionMNIST(
        '../data', train=False, download=True, 
        transform=transforms.ToTensor())
    testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
    model = Network()
    
    losses = []
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    epochs = 10
    
    for i in range(epochs):
        batch_loss = []
        for j, (data, targets) in enumerate(trainloader):
            optimizer.zero_grad()
            ypred = model(data)
            loss = criterion(ypred, targets.reshape(-1))
            loss.backward()
            optimizer.step()
            batch_loss.append(loss.item())
        if i>10: 
            optimizer.lr = 0.0005
        losses .append(sum(batch_loss) / len(batch_loss))
        print('Epoch {}:\tloss {:.4f}'.format(i, losses [-1]))
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2023-01-14
      • 1970-01-01
      • 2022-06-10
      • 2021-03-11
      • 2021-04-19
      • 2020-01-27
      • 2019-04-24
      相关资源
      最近更新 更多