【发布时间】:2021-02-01 00:31:42
【问题描述】:
我正在 CIFAR-10 数据集上训练一个神经网络(不管是哪一个)。我正在使用联邦学习:
- 我有 10 个模型,每个模型都可以访问自己的数据集部分。在每个时间步,每个模型使用自己的数据做一步,然后全局模型是模型的平均值(这个版本是基于this,但我尝试了很多选项):
def server_aggregate(server_model, client_models):
global_dict = server_model.state_dict()
for k in global_dict.keys():
global_dict[k] = torch.stack([client_models[i].state_dict()[k].float() for i in range(len(client_models))], 0).mean(0)
server_model.load_state_dict(global_dict)
for model in client_models:
model.load_state_dict(server_model.state_dict())
- 具体来说,每台机器只能访问与单个类对应的数据。 IE。机器
0仅具有与0类对应的样本等。我这样做的方式如下:
def split_into_classes(full_ds, batch_size, num_classes=10):
class2indices = [[] for _ in range(num_classes)]
for i, y in enumerate(full_ds.targets):
class2indices[y].append(i)
datasets = [torch.utils.data.Subset(full_ds, indices) for indices in class2indices]
return [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in datasets]
问题。在训练期间,我可以看到我的联合训练损失减少了。但是,我从来没有看到我的测试损失/准确性提高(acc 总是在 10% 左右)。 此外,当我检查train/test datasets 的准确性时:
- 对于联合数据集,准确性有所提高。
- 对于测试数据集,准确度没有提高。
- (最令人惊讶)对于训练数据集,准确度没有提高。请注意,此数据集本质上与联邦数据集相同,但未拆分为类。检查码是following:
def epoch_summary(model, fed_loaders, true_train_loader, test_loader, frac):
with torch.no_grad():
train_len = 0
train_loss, train_acc = 0, 0
for train_loader in fed_loaders:
cur_loss, cur_acc, cur_len = true_results(model, train_loader, frac)
train_loss += cur_len * cur_loss
train_acc += cur_len * cur_acc
train_len += cur_len
train_loss /= train_len
train_acc /= train_len
true_train_loss, true_train_acc, true_train_len = true_results(model, true_train_loader, frac)
test_loss, test_acc, test_len = true_results(model, test_loader, frac)
print("TrainLoss: {:.4f} TrainAcc: {:.2f} TrueLoss: {:.4f} TrueAcc: {:.2f} TestLoss: {:.4f} TestAcc: {:.2f}".format(
train_loss, train_acc, true_train_loss, true_train_acc, test_loss, test_acc
), flush=True)
完整代码可以在here找到。似乎无关紧要的事情:
- 型号。对于 Resnet 模型和其他一些模型,我遇到了同样的问题。
- 我如何聚合模型。我试过用
state_dict或者直接操作model.parameters(),没有效果。 - 我如何学习模型。我试过用
optim.SGD或者直接更新param.data -= learning_rate * param.grad,没有效果。 - 计算图。我尝试将
.detach().clone()和with torch.no_grad()添加到所有可能的位置,但没有效果。
所以我怀疑问题出在联合数据本身(特别是考虑到奇怪的准确性结果)。可能是什么问题?
【问题讨论】:
-
“在训练过程中,我可以看到我的训练损失减少了。但是,我从来没有看到我的测试损失/准确性提高”是整个火车上的联邦模型/测试数据集?
-
我不确定我是否理解这个问题,但是可以。请注意,我根据完整的训练数据估计评估模型(因此它应该是完美的,但事实并非如此)。
-
@Ivan,实验表明问题出在批量标准化(当我使用密集模型或用身份层替换 BN 层时,它可以工作)。看起来你不应该这样平均。您知道在模型平均期间处理它们的正确方法是什么吗?
-
"machine 0 只有与 class 0 对应的样本" 我不明白你将如何为 1-class 训练模型如果机器
0只有0类的数据点,则分类任务。这对我来说没有意义。 -
我认为Local SGD Converges Fast and Communicates Little 涵盖了我提到的两点。
标签: python neural-network pytorch