【发布时间】:2021-01-12 17:37:51
【问题描述】:
我正在 PyTorch 中编写一个众所周知的问题 MNIST database of handwritten digits 的代码。我下载了训练和测试数据集(从主网站),包括标记的数据集。数据集格式为t10k-images-idx3-ubyte.gz,提取后为t10k-images-idx3-ubyte。我的数据集文件夹看起来像
MINST
Data
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
现在,我编写了如下代码来加载数据
def load_dataset():
data_path = "/home/MNIST/Data/"
xy_trainPT = torchvision.datasets.ImageFolder(
root=data_path, transform=torchvision.transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
xy_trainPT, batch_size=64, num_workers=0, shuffle=True
)
return train_loader
我的代码显示Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp
我该如何解决这个问题,并且我还想检查我的图像是否已从数据集中加载(只有一个图包含前 5 张图像)?
【问题讨论】:
标签: python-3.x machine-learning deep-learning pytorch