【问题标题】:Augmenting only the training set in K-folds cross validation仅在 K-folds 交叉验证中增加训练集
【发布时间】:2019-12-23 15:50:00
【问题描述】:

我正在尝试为不平衡数据集(0 类 = 4000 个图像,1 类 = 大约 250 个图像)创建一个二进制 CNN 分类器,我想对其执行 5 折交叉验证。目前,我正在将训练集加载到 ImageLoader 中,该 ImageLoader 应用我的转换/增强(?)并将其加载到 DataLoader 中。但是,这会导致我的训练拆分和验证拆分都包含增强数据。

我最初应用离线转换(离线增强?)来平衡我的数据集,但从这个线程 (https://stats.stackexchange.com/questions/175504/how-to-do-data-augmentation-and-train-validate-split) 看来,只增强训练集似乎是理想的。我还希望在仅增强训练数据上训练我的模型,然后在 5 折交叉验证中在非增强数据上对其进行验证

我的数据以 root/label/images 的形式组织,其中有 2 个标签文件夹(0 和 1)和图像分类到各自的标签中。

到目前为止我的代码

total_set = datasets.ImageFolder(ROOT, transform = data_transforms['my_transforms'])

//Eventually I plan to run cross-validation as such:
splits = KFold(cv = 5, shuffle = True, random_state = 42)

for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=valid_sampler)

model.train()
//Model train/eval works but may be overpredict 

我确定我在这段代码中做的不是最佳或错误,但我似乎找不到任何关于专门增加交叉验证中的训练拆分的文档!

任何帮助将不胜感激!

【问题讨论】:

    标签: python deep-learning pytorch


    【解决方案1】:

    一种方法是实现一个包装数据集类,该类将转换应用于 ImageFolder 数据集的输出。例如

    class WrapperDataset:
        def __init__(self, dataset, transform=None, target_transform=None):
            self.dataset = dataset
            self.transform = transform
            self.target_transform = target_transform
    
        def __getitem__(self, index):
            image, label = self.dataset[index]
            if self.transform is not None:
                image = self.transform(image)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return image, label
    
        def __len__(self):
            return len(self.dataset)
    

    然后,您可以通过使用不同的转换包装更大的数据集,在您的代码中使用它。

    total_set = datasets.ImageFolder(ROOT)
    
    # Eventually I plan to run cross-validation as such:
    splits = KFold(cv = 5, shuffle = True, random_state = 42)
    
    for train_idx, valid_idx in splits.split(total_set):
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
    
        train_loader = torch.utils.data.DataLoader(
            WrapperDataset(total_set, transform=data_transforms['train_transforms']),
            batch_size=32, sampler=train_sampler)
        valid_loader = torch.utils.data.DataLoader(
            WrapperDataset(total_set, transform=data_transforms['valid_transforms']),
            batch_size=32, sampler=valid_sampler)
    
        # train/validate now
    

    我没有测试过这段代码,因为我没有你的完整代码/模型,但概念应该很清楚。

    【讨论】:

    • 感谢您的回复。我尝试实现您的想法,我认为它接近为我的代码工作。当我在训练时尝试迭代我的 train_loader 时,我得到一个 TypeError "img should be PIL Image. Got " 与你的类 WrapperDataset 相关的 "image = self.transform(image)" (代码会be:对于输入,train_loader中的标签:#train等
    • 就我的问题而言,我在“image = self.transform(image)”之前添加了“image = transforms.ToPILImage()(image)”,这解决了错误。再次感谢您的帮助!
    • 有趣,我不知道为什么会这样,因为ImageDataset 默认应该返回 PIL 图像。您在定义 total_set 时删除了转换,对吗?
    • 你说得对,它工作正常。我正在使用 ipython 笔记本,可能忘记重新初始化 total_set。哇!
    猜你喜欢
    • 2018-05-03
    • 2016-02-25
    • 2021-12-10
    • 2016-05-12
    • 2020-09-14
    • 2022-11-08
    • 2011-12-16
    • 2019-05-18
    • 2016-01-29
    相关资源
    最近更新 更多