【问题标题】:'DataLoader' object does not support indexing“DataLoader”对象不支持索引
【发布时间】:2019-11-12 06:25:47
【问题描述】:

我通过设置 download=True 通过这个 pytorch api 下载了 ImageNet 数据集。但我无法遍历数据加载器。

错误提示“'DataLoader' 对象不支持索引”

trainset = torch.utils.data.DataLoader(
    datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
                      download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)

我尝试了一种简单的方法,我只是尝试运行以下命令,

trainloader[0]

在根目录下,模式为

root/  
    train/  
          n01440764/
          n01443537/ 
                   n01443537_2.jpg

官方网站上的文档没有说别的。 https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet

我做错了什么?

【问题讨论】:

  • 在您的示例中,您正在从DataLoader 创建DataLoader,这是一个错误还是您的真实代码?
  • 是的,这是真正的代码

标签: python computer-vision pytorch imagenet


【解决方案1】:

嗯,答案很简单(除了另一个答案中提到的错误)。

DataLoader 没有__getitem__ 方法(请自行查看in the source code)。

它用于对数据(或成批数据)进行迭代,而不是随机访问。如果你想访问特定元素,你应该使用torch.utils.data.Dataset,在你的情况下:

trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', )
trainset[0]

获取批次

如果您想获得一批,您可以对其进行迭代并在之后中断:

for batch in dataloader:
    print(batch) # or anything else you want to do
    break

DataLoader 以默认或指定方式创建随机索引(请参阅samplers),因此没有__getitem__,因为它对这个对象没有意义。

您也可以从 DataLoader 继承并创建自己的 __getitem__ 函数来做您想做的事情(虽然更复杂)。

完整示例

# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', download=True)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)

for batch in trainloader:
    print(batch)
    break

上面应该打印第一批里面的东西。

【讨论】:

  • 那我怎么像批量一样得到呢?
  • 澄清一下,根据 Anubhav Singh 的代码,“dataloader”是什么意思,是“trainset”还是“trainloader”?因为 trainloader 它不起作用!
  • 如果您从train_dataset 创建DataLoader,它就可以工作。 dataloader 指的是此类 DataLoader 类的实例。
  • 我真的很困惑。你能解释一下还是更新你的代码
  • 嗯,是的,也不是。我面临的问题是 img 是一个 PIL 文件,而调试器没有说出来。您的代码仍然无法正常运行。修复方法是添加“standard_transforms.ToTensor()”作为变换。 :)
【解决方案2】:

解决办法

input_transform = standard_transforms.Compose([
    transforms.Resize((255,255)), # to Make sure all the 
    transforms.CenterCrop(224),   # imgs are at the same size 
    transforms.ToTensor()
])  


# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
                             split='train', download=False, transform = input_transform)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=False)


for batch_idx, data in enumerate(trainloader, 0):
    x, y = data 
    break

【讨论】:

    【解决方案3】:

    torch.utils.data.DataLoader() 的输入数据集应该是 torch.utils.data.Dataset 类型,而不是 torch.utils.data.DataLoader,这就是您在上面的代码中所做的。

    所以,你上面的代码应该是:

    trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', 
                                              split='train', 
                                              download=False)
    
    trainloader = torch.utils.data.DataLoader(trainset, 
                                              batch_size=1, 
                                              shuffle=False, 
                                              num_workers=1)
    

    更多详情,请查看官方火炬文档here

    【讨论】:

    • 是的,我看到了问题,并尝试了您的解决方案。当我执行“trainloader [0]”时,我仍然遇到相同的错误“'DataLoader' 对象不支持索引”
    • 虽然属实,但它并不能解决问题(更不用说重申评论的事实了)。
    猜你喜欢
    • 2014-04-23
    • 2017-09-20
    • 1970-01-01
    • 1970-01-01
    • 2013-08-23
    • 2018-01-07
    • 2013-06-23
    • 2019-04-18
    • 2017-12-28
    相关资源
    最近更新 更多