【问题标题】:computing gradients for every individual sample in a batch in PyTorch在 PyTorch 中为批次中的每个样本计算梯度
【发布时间】:2019-05-16 19:18:45
【问题描述】:

我正在尝试实现一个差分私有随机梯度下降版本(例如,this),如下所示:

计算大小为 L 的批次中每个点的梯度,然后分别裁剪 L 个梯度中的每一个,然后将它们平均在一起,最后执行(嘈杂的)梯度下降步骤。

在 pytorch 中最好的方法是什么?

最好有一种方法可以同时计算批次中每个点的梯度:

x # inputs with batch size L
y #true labels
y_output = model(x)
loss = loss_func(y_output,y) #vector of length L
loss.backward() #stores L distinct gradients in each param.grad, magically

但如果做不到这一点,请分别计算每个梯度,然后在累积之前裁剪范数,但是

x # inputs with batch size L
y #true labels
y_output = model(x)
loss = loss_func(y_output,y) #vector of length L   
for i in range(loss.size()[0]):
    loss[i].backward(retain_graph=True)
    torch.nn.utils.clip_grad_norm(model.parameters(), clip_size)

先累积第 i 个梯度,然后进行剪辑,而不是在将其累积到梯度中之前进行剪辑。解决此问题的最佳方法是什么?

【问题讨论】:

    标签: python pytorch gradient-descent


    【解决方案1】:

    我认为在计算效率方面你不能比第二种方法做得更好,你正在失去 backward 中批处理的好处,这是事实。关于裁剪的顺序,autograd 将梯度存储在参数张量的.grad 中。一个粗略的解决方案是添加一个字典,如

    clipped_grads = {name: torch.zeros_like(param) for name, param in net.named_parameters()}
    

    像这样运行你的 for 循环

    for i in range(loss.size(0)):
        loss[i].backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_(net.parameters())
        for name, param in net.named_parameters():
            clipped_grads[name] += param.grad / loss.size(0)
        net.zero_grad()
    
    for name, param in net.named_parameters():
        param.grad = clipped_grads[name]
    
    optimizer.step()
    

    我省略了很多 detachrequires_grad=False 和类似的业务,这些业务可能是使其按预期运行所必需的。

    上述方法的缺点是您最终为参数梯度存储了 2 倍的内存。原则上,您可以采用“原始”渐变,对其进行剪辑,添加到clipped_gradient,然后在没有下游操作需要时立即丢弃,而在这里您保留grad 中的原始值直到反向传递结束. 可能register_backward_hook 允许你这样做,如果你违反指南并实际修改 grad_input,但你必须与更熟悉 autograd 的人核实。

    【讨论】:

      【解决方案2】:

      This 包并行计算每个样本的梯度。所需的内存仍然是标准随机梯度下降的batch_size 倍,但由于并行化,它可以运行得更快。

      【讨论】:

      • 在哪里?我看到一个指向差异隐私库的链接,它可能会这样做,但在哪里以及如何?
      猜你喜欢
      • 1970-01-01
      • 2021-09-11
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-05-12
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多