【发布时间】:2019-01-25 03:09:02
【问题描述】:
我正在研究一个包含 31 个类(Office 数据集)的图像分类器。每个类都有一个文件夹。我有一个使用 PyTorch 编写的 python 脚本,它使用 datasets.ImageFolder 加载数据集并为每个图像分配一个标签,然后进行训练。这是我用于加载数据的代码 sn-p:
from torchvision import datasets, transforms
import torch
def load_training(root_path, dir, batch_size, kwargs):
transform = transforms.Compose(
[transforms.Resize([256, 256]),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()])
data = datasets.ImageFolder(root=root_path + dir, transform=transform)
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
return train_loader
代码获取每个文件夹,为该文件夹中的所有图像分配相同的标签。有没有办法找到哪个标签分配给哪个图像/图像文件夹?
【问题讨论】:
标签: python-3.x image-processing machine-learning pytorch