【问题标题】:How do you load MNIST images into Pytorch DataLoader?如何将 MNIST 图像加载到 Pytorch DataLoader 中?
【发布时间】:2018-10-07 16:51:52
【问题描述】:

用于数据加载和处理的 pytorch 教程非常具体地针对一个示例,有人可以帮助我了解更通用的简单图像加载功能应该是什么样的吗?

教程:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

我的数据:

我将 MINST 数据集作为 jpg 格式保存在以下文件夹结构中。 (我知道我可以只使用数据集类,但这纯粹是为了看看如何将简单的图像加载到没有 csv 或复杂功能的 pytorch 中)。

文件夹名称是标签,图像是 28x28 png 的灰度,不需要转换。

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    如果您使用的是 mnist,那么 pytorch 中已经有一个通过 torchvision 进行的预设。
    你可以这样做

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import pandas as pd
    
    transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
    mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                          shuffle=True, num_workers=2)
    

    如果你想推广到图像目录(与上面相同的导入),你可以这样做

    class mnistmTrainingDataset(torch.utils.data.Dataset):
    
        def __init__(self,text_file,root_dir,transform=transformMnistm):
            """
            Args:
                text_file(string): path to text file
                root_dir(string): directory with all train images
            """
            self.name_frame = pd.read_csv(text_file,sep=" ",usecols=range(1))
            self.label_frame = pd.read_csv(text_file,sep=" ",usecols=range(1,2))
            self.root_dir = root_dir
            self.transform = transform
    
        def __len__(self):
            return len(self.name_frame)
    
        def __getitem__(self, idx):
            img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
            image = Image.open(img_name)
            image = self.transform(image)
            labels = self.label_frame.iloc[idx, 0]
            #labels = labels.reshape(-1, 2)
            sample = {'image': image, 'labels': labels}
    
            return sample
    
    
    mnistmTrainSet = mnistmTrainingDataset(text_file ='Downloads/mnist_m/mnist_m_train_labels.txt',
                                       root_dir = 'Downloads/mnist_m/mnist_m_train')
    
    mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet,batch_size=16,shuffle=True, num_workers=2)
    

    然后你可以像这样迭代它:

    for i_batch,sample_batched in enumerate(mnistmTrainLoader,0):
        print("training sample for mnist-m")
        print(i_batch,sample_batched['image'],sample_batched['labels'])
    

    pytorch 用于图像数据集加载的泛化方法有很多,我知道的方法是子类化torch.utils.data.dataset

    【讨论】:

    • 加载同一个文件并独立访问它的两列两次是非常低效的!
    • 是的。而是使用一个数据框。在read_csv 中使用index_col=False 加载它以获得数字索引。然后在__getitem__ 中使用self.df.at[idx, "filename"]self.df.at[idx, "label"]
    【解决方案2】:

    这是我为 pytorch 0.4.1 所做的(在 1.3 中应该仍然可以使用)

    def load_dataset():
        data_path = 'data/train/'
        train_dataset = torchvision.datasets.ImageFolder(
            root=data_path,
            transform=torchvision.transforms.ToTensor()
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=64,
            num_workers=0,
            shuffle=True
        )
        return train_loader
    
    for batch_idx, (data, target) in enumerate(load_dataset()):
        #train network
    

    【讨论】:

    • 您的 load_dataset() 函数中如何指定类标签?
    • ImageFolder 根据类文件夹生成:pytorch.org/docs/stable/torchvision/…
    • 对于 MNIST 可能需要使用“transforms.Grayscale()”:test_dataset = torchvision.datasets.ImageFolder( root=data_path, transform=transforms.Compose([transforms.Grayscale(), transforms .ToTensor()]) )
    猜你喜欢
    • 2019-05-02
    • 1970-01-01
    • 1970-01-01
    • 2021-02-18
    • 2020-03-10
    • 2022-01-24
    • 2019-07-29
    • 2020-08-07
    • 2019-07-23
    相关资源
    最近更新 更多