【问题标题】:trying to import png images to torchvision试图将 png 图像导入到 torchvision
【发布时间】:2020-11-29 13:50:53
【问题描述】:

我正在尝试导入用于 torch 和 torchvision 的图像。但我收到此错误:

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "c:\python38\lib\site-packages\torch\utils\data\_utils\worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "c:\python38\lib\site-packages\torch\utils\data\_utils\fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "c:\python38\lib\site-packages\torch\utils\data\_utils\collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "c:\python38\lib\site-packages\torch\utils\data\_utils\collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "c:\python38\lib\site-packages\torch\utils\data\_utils\collate.py", line 81, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

根据这篇文章,我将它们转换为张量:

https://discuss.pytorch.org/t/typeerror-default-collate-batch-must-contain-tensors-numpy-arrays-numbers-dicts-or-lists-found-class-imageio-core-util-array/62667

这是我的代码:

import torch
import torchvision
import torchvision.transforms
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor()
])

dataset = torchvision.datasets.ImageFolder('datasets')

dataloader = torch.utils.data.DataLoader(dataset,
                                          batch_size=16,
                                          shuffle=True,
                                          num_workers=12)

tensor_dataset = []

for i, data in enumerate(dataloader, 0):
    Tensor = torch.tensor(data)
    tensor_dataset.append(Tensor.flatten)

最后一部分来自https://github.com/TerragonDE/PyTorch,但我没有成功。我要加载的数据来自这里:

http://www.cvlibs.net/datasets/kitti/

我该如何解决这个问题?

更新:

感谢@trialNerror,但现在我收到此错误:

ValueError                                Traceback (most recent call last)
<ipython-input-6-aa72392b67e8> in <module>
      1 for i, data in enumerate(dataloader, 0):
----> 2     Tensor = torch.tensor(data)
      3     tensor_dataset.append(Tensor.flatten)

ValueError: only one element tensors can be converted to Python scalars

这是我目前发现的,但不知道如何应用它:

https://discuss.pytorch.org/t/pytorch-autograd-grad-only-one-element-tensors-can-be-converted-to-python-scalars/56681

更新 2:

我最终没有使用数据加载器的原因是因为我最终得到了这个错误:

num_epochs = 10
loss_values = list()

for epoch in range(1, num_epochs):
    for i, data in enumerate(train_array, 0):     
        outputs = model(data.unsqueeze(0))
        loss = criterion(outputs,data.unsqueeze(0))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        print('Epoch - %d, loss - %0.5f '%(epoch, loss.item()))
        loss_values.append(loss.item())

torch.Size([1, 16, 198, 660]) 火炬尺寸([1, 32, 97, 328]) torch.Size([1, 1018112])

RuntimeError                              Traceback (most recent call last)
<ipython-input-106-5e6fa86df079> in <module>
      4 for epoch in range(1, num_epochs):
      5     for i, data in enumerate(train_array, 0):
----> 6         outputs = model(data.unsqueeze(0))
      7         loss = criterion(outputs,data.unsqueeze(0))
      8 

c:\python38\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-90-467a3f84a03f> in forward(self, x)
     29         print(out.shape)
     30 
---> 31         out = self.fc(out)
     32         print(out.shape)
     33 

c:\python38\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

c:\python38\lib\site-packages\torch\nn\modules\linear.py in forward(self, input)
     85 
     86     def forward(self, input):
---> 87         return F.linear(input, self.weight, self.bias)
     88 
     89     def extra_repr(self):

c:\python38\lib\site-packages\torch\nn\functional.py in linear(input, weight, bias)
   1608     if input.dim() == 2 and bias is not None:
   1609         # fused op is marginally faster
-> 1610         ret = torch.addmm(bias, input, weight.t())
   1611     else:
   1612         output = input.matmul(weight.t())

RuntimeError: size mismatch, m1: [1 x 1018112], m2: [512 x 10] at C:\w\b\windows\pytorch\aten\src\TH/generic/THTensorMath.cpp:41

我意识到如果你有 m1: [a * b] 和 m2: [c * d] 那么 b 和 c 必须是相同的值,但我不确定,调整图像大小的最佳方法是什么?

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    您的 transform 变量未使用,应将其传递给 Dataset 构造函数:

    `dataset = torchvision.datasets.ImageFolder('datasets', transform=transform)`
    

    因此,ToTensor 永远不会应用于您的数据,因此它们仍然是 PIL 图像,而不是张量。

    【讨论】:

      【解决方案2】:

      请注意,我想将目录中的所有 PNG 图像自动加载为 pytorch 张量。我以前看过这样的帖子(以及许多其他网页):

      Loading a huge dataset batch-wise to train pytorch

      但我最终在目录中的所有图像上使用Image.open,而不是torch DataLoader

      import numpy as np
      from PIL import Image
      import matplotlib.pyplot as plt
      
      image = Image.open("datasets/image_02/data/my_image.png").convert('RGB')
      
      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      from torchvision import transforms as tf
      
      transforms = tf.Compose([tf.Resize(400), 
                              tf.ToTensor()])
      
      img_tensor = transforms(image)
      

      【讨论】:

      • 我不确定我是否理解这里。此代码与您的问题完全不同,您实际使用的是哪一个?为什么你会选择不使用 pytorch 数据集/数据加载器并最终得到这个代码?
      • 最初我尝试使用 torch.utils.data.DataLoader 但由于我在问题中提到的错误而无法使其正常工作,因此我选择使用 Image.open。
      猜你喜欢
      • 2020-01-05
      • 2017-12-31
      • 2018-02-06
      • 2022-12-25
      • 1970-01-01
      • 2013-08-20
      • 1970-01-01
      • 1970-01-01
      • 2012-02-20
      相关资源
      最近更新 更多