【问题标题】:Running through a dataloader in Pytorch using Google Colab使用 Google Colab 在 Pytorch 中运行数据加载器
【发布时间】:2018-09-27 11:07:51
【问题描述】:

我正在尝试使用 Pytorch 对猫狗图像数据集进行分类。在我的代码中,到目前为止,我正在下载数据并进入文件夹 train,其中有两个文件夹,分别称为“cats”和“dogs”。然后我尝试将此数据加载到数据加载器中并迭代批次,但它给了我一些我在迭代步骤中不理解的错误。

因为它是 Google Colabs,所以我在其中有用于下载数据和安装库的代码。到目前为止,对我的代码的任何其他建议也将不胜感激。

!pip install torch
!pip install torchvision

from __future__ import print_function, division
import os
import torch
import pandas as pd
import numpy as np
# For showing and formatting images
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# For importing datasets into pytorch
import torchvision.datasets as dataset

# Used for dataloaders
import torch.utils.data as data

# For pretrained resnet34 model
import torchvision.models as models

# For optimisation function
import torch.nn as nn
import torch.optim as optim


!wget http://files.fast.ai/data/dogscats.zip
!unzip dogscats.zip    

batch_size = 256

train_raw = dataset.ImageFolder(PATH+"train", transform=transforms.ToTensor())
train_loader = data.DataLoader(train_raw, batch_size=batch_size, shuffle=True)

for batch_idx, (data, target) in enumerate(train_loader):
  print("Data: ", batch_idx)

错误出现在最后几行,如下:

RuntimeErrorTraceback (most recent call last)
<ipython-input-66-c32dd0c1b880> in <module>()
----> 1 for batch_idx, (data, target) in enumerate(train_loader):
      2   print("Data: ", batch_idx)
      3 

/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.pyc in __next__(self)
    257         if self.num_workers == 0:  # same-process loading
    258             indices = next(self.sample_iter)  # may raise StopIteration
--> 259             batch = self.collate_fn([self.dataset[i] for i in indices])
    260             if self.pin_memory:
    261                 batch = pin_memory_batch(batch)

/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.pyc in default_collate(batch)
    133     elif isinstance(batch[0], collections.Sequence):
    134         transposed = zip(*batch)
--> 135         return [default_collate(samples) for samples in transposed]
    136 
    137     raise TypeError((error_msg.format(type(batch[0]))))

/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.pyc in default_collate(batch)
    110             storage = batch[0].storage()._new_shared(numel)
    111             out = batch[0].new(storage)
--> 112         return torch.stack(batch, 0, out=out)
    113     elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
    114             and elem_type.__name__ != 'string_':

/usr/local/lib/python2.7/dist-packages/torch/functional.pyc in stack(sequence, dim, out)
     62     inputs = [t.unsqueeze(dim) for t in sequence]
     63     if out is None:
---> 64         return torch.cat(inputs, dim)
     65     else:
     66         return torch.cat(inputs, dim, out=out)

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 400 and 487 in dimension 2 at /pytorch/torch/lib/TH/generic/THTensorMath.c:2897

谢谢

【问题讨论】:

    标签: python-3.x image deep-learning image-recognition pytorch


    【解决方案1】:

    我认为主要问题是图像大小不同。我可能以其他方式理解 ImageFolder,但是,如果目录结构与 pytorch 中指定的一样,我认为您不需要图像标签,并且 pytorch 会为您找出标签。 我还会在您的转换中添加更多内容,以自动调整文件夹中每个图像的大小,例如:

       normalize = transforms.Normalize(
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]
                            )
       transform = transforms.Compose(
            [transforms.ToTensor(),transforms.Resize((224,224)),
             normalize])
    

    您还可以使用其他技巧来使您的 DataLoader 更快,例如添加 batch_size 和 cpu worker 的数量,例如:

        testloader = DataLoader(testset, batch_size=16,
                             shuffle=False, num_workers=4)
    

    我认为这将使您的流水线速度更快。

    【讨论】:

    • 哇,谢谢 Manoj。昨晚我按照自己的方式运行了这个,几个小时后,大麦通过了所有的训练数据,损失让大麦移动了。现在我实施了您的更改,几分钟后它已经达到了 8%,并且在几个小时和更多数据之后,损失已经急剧下降到比昨晚 my'n 少。如果你不介意的话,还有两个问题。 1) 你是如何得到 Normalize 中的均值和标准差的值 2) 你是如何得出 batch_size 的。
    • 我找到了以下文档,但它没有解释如何为其派生值。 pytorch.org/docs/master/torchvision/transforms.html
    • 这些值是在 pytorch.org/docs/master/torchvision/models.html 的文档中为预训练的 pytorch 模型提供的。对于批量大小,您只需进行试验,但通常批量 256 个图像对于 GPU 来说是可以的。
    【解决方案2】:

    我认为我对 Manoj Acharya 的评论是错误的,问题在于将 batch_size 放入数据加载器中。我阅读了以下来源,您似乎无法将不同尺寸的图像一起批处理:

    https://medium.com/@yvanscher/pytorch-tip-yielding-image-sizes-6a776eb4115b

    所以在我的代码中更改数据变量 Manoj 指出我将 batch_size 更改为 1 并且程序停止失败。我想分批放置,所以我添加了进一步的变换 CenterCrop() 以将所有图像的大小调整为相同的大小。以下是我的新代码:

    !pip install torch
    !pip install torchvision
    
    from __future__ import print_function, division
    import os
    import torch
    import pandas as pd
    import numpy as np
    # For showing and formatting images
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg    
    # For importing datasets into pytorch
    import torchvision.datasets as dataset    
    # Used for dataloaders
    from torch.utils.data import DataLoader
    # For pretrained resnet34 model
    import torchvision.models as models    
    # For optimisation function
    import torch.nn as nn
    import torch.optim as optim    
    # For turning data into tensors
    import torchvision.transforms as transforms
    
    !wget http://files.fast.ai/data/dogscats.zip
    !unzip dogscats.zip
    
    batch_size = 256
    sz = 224
    
    train_raw = dataset.ImageFolder(PATH+"train", transform=transforms.Compose([transforms.CenterCrop(sz),transforms.ToTensor()]))
    train_loader = DataLoader(train_raw,batch_size=batch_size, shuffle=True)
    
    for batch_idx, (data, target) in enumerate(train_loader):
      print("Data: ", batch_idx)
    

    谢谢

    【讨论】:

      【解决方案3】:

      我在您的代码中看到两个问题,首先您将 import torch.utils.data 作为数据导入,然后再次在数据加载器中替换它。请将导入的模块和您的变量名称保存在单独的命名空间中。我认为这个错误可能是因为 dataloder(images) 和标签返回的数据大小不同。如您所见,连接存在错误,因为第一个维度即。标签大小和文件夹中的图像数量不匹配。希望这会有所帮助。

      【讨论】:

      • 感谢您的回复 Manoj Acharya,我已经解决了将 torch.utils.data 重命名为 data_util 的第一个问题,但第二个问题有点卡住了。我认为我不太了解 dataset.ImageFolder 的工作原理。在我给它的路径中有两个文件夹狗和猫,每个文件夹都充满了动物的图像。所以我不明白标签如何与图像不匹配,因为图像是根据它所在的文件夹标记的。dataset.ImageFolder 是否选择了这个?还是我需要另一个标签文件作为 .csv 或其他文件?
      猜你喜欢
      • 2018-09-10
      • 1970-01-01
      • 2020-08-21
      • 2019-01-02
      • 2019-12-29
      • 2019-05-31
      • 1970-01-01
      • 2020-03-20
      • 1970-01-01
      相关资源
      最近更新 更多