【问题标题】:Creating reduced Dataset from existing Torchvision Dataset从现有 Torchvision 数据集创建精简数据集
【发布时间】:2019-01-31 19:51:16
【问题描述】:

我们都知道常见的 MNIST 数据集,包含在 torchvision.datasets 包中。想象一下,我想创建一个仅包含 10 的数据集的简化版本,以仅对这两个数字进行分类,而不是对所有 10 个值进行分类。

我已经看到可以在继承所需数据集的类中创建自定义数据集,例如__getitem__,它返回给定索引处的项目。所以我这样做了:

class MNIST01(MNIST):
    def __getitem__(self, idx):
        image, label = super().__getitem__(idx)
        if label.item() <= 1:
            return image, label
        else:
            return None

问题是我似乎无法返回 None 值,因为它需要“包含张量、数字、字典或列表;找到类 'NoneType'”。

有没有一种简单的方法可以以类似的方式轻松获得此数据集的简化版本?

【问题讨论】:

    标签: python dataset torchvision


    【解决方案1】:

    我终于设法处理了 NoneType 问题。保留问题中定义的函数。

    class MNIST01(MNIST):
        def __getitem__(self, idx):
            features, target = super(MNIST01, self).__getitem__(idx)
            if target.item() <= 1:
                return features, target
    

    我们现在需要为我们的数据加载器定义一个自定义的collate functioncollate_fn,它处理样本列表以形成一个批次。在这个函数中,我们可以应用过滤器来处理Nonevalues并忽略它们。

    from torch.utils.data.dataloader import default_collate
    
    def filter_collate(batch):
        batch = list(filter(lambda x: x is not None, batch))
        return default_collate(batch)
    

    那么我们只需要将这个函数传递给DataLoader

    from torch.utils.data import DataLoader
    
    train_loader = DataLoader(train_dataset, collate_fn=filter_collate, **kwargs)
    test_loader = DataLoader(test_dataset, collate_fn=filter_collate, **kwargs)
    

    第 2 版

    比第一个容易得多,避免了访问数据时的一些问题。只需从MNIST 类的实例化中直接过滤train_datatrain_label 属性(以及对应的测试集)。

    train_dataset.train_data = train_dataset.train_data[train_dataset.train_labels <= 1]
    train_dataset.train_labels = train_dataset.train_labels[train_dataset.train_labels <= 1]
    

    【讨论】:

      猜你喜欢
      • 2021-06-20
      • 2021-08-21
      • 2021-11-05
      • 2018-12-06
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多