【问题标题】:Iterating over subsets from torch.utils.data.random_split迭代来自 torch.utils.data.random_split 的子集
【发布时间】:2020-05-25 18:33:00
【问题描述】:

我目前正在加载一个包含 AI 训练数据的文件夹。子文件夹代表标签名称,其中包含相应的图像。这通过使用 pyTorch 的 ImageFolder 加载器效果很好。

def load_dataset():
    data_path = 'C:/example_folder/'

    train_dataset_manual = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )

    train_loader_manual = torch.utils.data.DataLoader(
        train_dataset_manual,
        batch_size=1,
        num_workers=0,
        shuffle=True
    )

    return train_loader_manual

full_dataset = load_dataset()

现在我想将此数据集拆分为训练数据集和测试数据集。我为此使用了 random_split 函数:

training_data_size = 0.8

train_size = int(training_data_size * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

full_dataset 是torch.utils.data.dataloader.DataLoader 类型的对象。我可以用这样的循环遍历它:

for batch_idx, (data, target) in enumerate(full_dataset):
    print(batch_idx)

train_datasettorch.utils.data.dataset.Subset 类型的对象。如果我尝试遍历它,我会得到:

TypeError 'DataLoader' 对象不可下标:

for batch_idx, (data, target) in enumerate(train_dataset):
    print(batch_idx)

我如何循环遍历它?我对 Python 比较陌生。

谢谢!

【问题讨论】:

    标签: python loops pytorch


    【解决方案1】:

    您需要将random_split 应用于Dataset 而不是DataLoader。用于定义DataLoader 的数据集在DataLoader.dataset 成员中可用。

    例如你可以这样做

    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset.dataset, [train_size, test_size])
    
    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)
    

    然后您可以按预期迭代train_loadertest_loader

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2017-02-22
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2015-08-12
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多