【Posted】: 2022-01-19 23:17:58
【Question】:
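I'm new to Python. How can I modify this class so that it filters the files in a folder by a string? Right now it returns every file in the folder, which could be millions of items. The code below works, but I want to select only the files whose names contain a specific string, for example all files containing 'v_1234_frame':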
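import torch
from torchvision import transforms

# Image loader: convert each image to a tensor and rescale from [0, 1] to [0, 255]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255))
])
# utils is my own module; it holds the ImageFolderWithPaths class shown below
image_dataset = utils.ImageFolderWithPaths(folder_containing_the_content_folder, transform=transform)
image_loader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size)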
This is the working class that needs to be modified so that it keeps only file names containing 'v_1234_frame':
from torchvision import datasets

class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths.
    Extends torchvision.datasets.ImageFolder()
    Reference: https://discuss.pytorch.org/t/dataloader-filenames-in-each-batch/4212/2
    """
    # override the __getitem__ method. this is the method dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns
        original_tuple = super().__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = (*original_tuple, path)
        return tuple_with_path
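I'm still learning Python and can't come up with a solution. I'd appreciate any help or suggestions on how to change the class or the way it is called.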
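One possible direction, offered as a sketch rather than a confirmed answer: `torchvision.datasets.ImageFolder` accepts an `is_valid_file` callable that receives each candidate path and returns whether to keep it, and `ImageFolderWithPaths` inherits that constructor argument unchanged. The helper names `make_name_filter`, `build_filtered_dataset` and `IMAGE_EXTENSIONS` below are hypothetical and not part of torchvision. The extension list is also an assumption; adjust it to the formats actually present in the folder.

```python
import os

# Assumed set of image extensions; edit to match the files in your folder.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")


def make_name_filter(substring, extensions=IMAGE_EXTENSIONS):
    """Return a predicate that keeps image paths whose file name contains `substring`."""
    def is_valid_file(path):
        name = os.path.basename(path)
        # Supplying is_valid_file replaces ImageFolder's built-in extension
        # check, so the extension is tested here as well.
        return substring in name and name.lower().endswith(extensions)
    return is_valid_file


def build_filtered_dataset(dataset_cls, root, substring, **kwargs):
    """Instantiate an ImageFolder-style class with the file-name filter applied.

    `dataset_cls` is expected to accept `root` and `is_valid_file`, as
    torchvision.datasets.ImageFolder and its subclasses do.
    """
    return dataset_cls(root, is_valid_file=make_name_filter(substring), **kwargs)
```

With the code from the question, this would be called as `image_dataset = build_filtered_dataset(utils.ImageFolderWithPaths, folder_containing_the_content_folder, "v_1234_frame", transform=transform)`, and the `DataLoader` line stays the same. Two caveats: the directory is still scanned once when the dataset is created, so only the resulting sample list shrinks, and some torchvision versions raise an error when a class subfolder ends up with no matching files.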
【Discussion】:
Tags: python python-3.x pytorch-dataloader