【发布时间】:2020-05-25 18:33:00
【问题描述】:
我目前正在加载一个包含 AI 训练数据的文件夹。子文件夹代表标签名称,其中包含相应的图像。这通过使用 pyTorch 的 ImageFolder 加载器效果很好。
def load_dataset():
data_path = 'C:/example_folder/'
train_dataset_manual = torchvision.datasets.ImageFolder(
root=data_path,
transform=torchvision.transforms.ToTensor()
)
train_loader_manual = torch.utils.data.DataLoader(
train_dataset_manual,
batch_size=1,
num_workers=0,
shuffle=True
)
return train_loader_manual
full_dataset = load_dataset()
现在我想将此数据集拆分为训练数据集和测试数据集。我为此使用了 random_split 函数:
training_data_size = 0.8
train_size = int(training_data_size * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
full_dataset 是torch.utils.data.dataloader.DataLoader 类型的对象。我可以用这样的循环遍历它:
for batch_idx, (data, target) in enumerate(full_dataset):
print(batch_idx)
train_dataset 是 torch.utils.data.dataset.Subset 类型的对象。如果我尝试遍历它,我会得到:
TypeError 'DataLoader' 对象不可下标:
for batch_idx, (data, target) in enumerate(train_dataset):
print(batch_idx)
我如何循环遍历它?我对 Python 比较陌生。
谢谢!
【问题讨论】: