【Question Title】: Pytorch modify dataset label
【Posted】: 2018-12-13 09:39:54
【Question Description】:

Here is a code snippet from the PyTorch transfer learning tutorial that loads images as a dataset:

import os

import torch
from torchvision import datasets, transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

Here is one of the examples in the dataset:

image_datasets['val'][0]:

(tensor([[[ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          ...,
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489]],

         [[ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          ...,
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286],
          [ 2.4286,  2.4286,  2.4286,  ...,  2.4286,  2.4286,  2.4286]],

         [[ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400],
          [ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400],
          [ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400],
          ...,
          [ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400],
          [ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400],
          [ 2.6400,  2.6400,  2.6400,  ...,  2.6400,  2.6400,  2.6400]]]), 0)

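Is there any way (ideally a best practice) to change an example's data in the dataset, for example changing its label from 0 to 1? The following does not work: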

image_datasets['val'][0] = (image_datasets['val'][0][0], 1)

【Question Discussion】:

    Tags: python deep-learning pytorch transfer-learning


    【Solution 1】:

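    Yes, although not (easily) programmatically. The labels come from torchvision.datasets.ImageFolder and reflect the directory structure of your dataset (as laid out on your hard drive). First, I suspect you may want to know the directory names as strings. This is poorly documented, but the dataset object (the ImageFolder, not the DataLoader) has a classes attribute that stores them. So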

    img, lbl = image_datasets['val'][0]
    directory_name = image_datasets['val'].classes[lbl]
    

    If you want these to always be returned instead of class IDs, you can use the target_transform API as follows:

    image_datasets['val'].target_transform = lambda idx: image_datasets['val'].classes[idx]
    

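    From then on, the dataset returns strings instead of IDs. If you are looking for something more advanced, you can reimplement or inherit from ImageFolder or DatasetFolder and implement your own semantics. The only methods you need to provide are __len__ and __getitem__.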
