我想您想在训练时交替执行这两项任务。
我还将假设您将在同一个批次中混合这两个任务。
你可以自定义一个返回的Dataset
class MixedDataset(Dataset):
# ...
def __getitem__(self, index):
# ... get data according to index
return img, seg, seg_flag, class, class_flag
对于医学图像,seg 将是一个虚拟掩码,seg_flag 将为零,而class 将成为目标类,class_flag 为 1。
另一方面,对于自然图像,seg 将是所需的分割掩码,seg_flag 为 1,而class 将是一个虚拟图像,class_flag 为零。
现在你可以运行你的训练代码了:
for i, (img, seg, seg_flag, class, class_flag) in train_loader:
opt.zero_grad()
pred_mask, pred_class = model(img) # predict both
loss_seg = seg_flag * dice_loss_fuction(pred_mask, seg)
loss_class = class_flag * cross_entropy_loss_function(pred_class, class)
(loss_seg + loss_class).backward()
opt.step()