【发布时间】:2019-12-23 15:50:00
【问题描述】:
我正在尝试为不平衡数据集(0 类 = 4000 个图像,1 类 = 大约 250 个图像)创建一个二进制 CNN 分类器,我想对其执行 5 折交叉验证。目前,我正在将训练集加载到 ImageLoader 中,该 ImageLoader 应用我的转换/增强(?)并将其加载到 DataLoader 中。但是,这会导致我的训练拆分和验证拆分都包含增强数据。
我最初应用离线转换(离线增强?)来平衡我的数据集,但从这个线程 (https://stats.stackexchange.com/questions/175504/how-to-do-data-augmentation-and-train-validate-split) 看来,只增强训练集似乎是理想的。我还希望在仅增强训练数据上训练我的模型,然后在 5 折交叉验证中在非增强数据上对其进行验证
我的数据以 root/label/images 的形式组织,其中有 2 个标签文件夹(0 和 1)和图像分类到各自的标签中。
到目前为止我的代码
total_set = datasets.ImageFolder(ROOT, transform = data_transforms['my_transforms'])
//Eventually I plan to run cross-validation as such:
splits = KFold(cv = 5, shuffle = True, random_state = 42)
for train_idx, valid_idx in splits.split(total_set):
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=valid_sampler)
model.train()
//Model train/eval works but may be overpredict
我确定我在这段代码中做的不是最佳或错误,但我似乎找不到任何关于专门增加交叉验证中的训练拆分的文档!
任何帮助将不胜感激!
【问题讨论】:
标签: python deep-learning pytorch