【问题标题】:Pytorch: How to optimize multiple variables with respect to multiple losses?Pytorch:如何针对多重损失优化多个变量?
【发布时间】:2020-08-10 13:03:35
【问题描述】:

我希望针对不同变量计算不同损失的梯度,然后将这些变量一起计算。

这是一个简单的例子,展示了我想要什么:

import torch as T
x = T.randn(3, requires_grad = True)
y = T.randn(4, requires_grad = True)
z = T.randn(5, requires_grad = True)

x_opt = T.optim.Adadelta([x])
y_opt = T.optim.Adadelta([y])
z_opt = T.optim.Adadelta([z])

for i in range(n_iter):
  x_opt.zero_grad()
  y_opt.zero_grad()
  z_opt.zero_grad()

  shared_computation = foobar(x, y, z)

  x_loss = f(x, y, z, shared_computation)
  y_loss = g(x, y, z, shared_computation)
  z_loss = h(x, y, z, shared_computation)

  x_loss.backward_with_respect_to(x)
  y_loss.backward_with_respect_to(y)
  z_loss.backward_with_respect_to(z)

  x_opt.step()
  y_opt.step()
  z_opt.step()

我的问题是我们如何在 PyTorch 中完成 backward_with_respect_to 部分?我只想要x 的渐变w.r.t。 x_loss 等。然后我希望所有优化器齐头并进(基于 xyz 的当前值)。

【问题讨论】:

    标签: optimization pytorch autograd


    【解决方案1】:

    我写了一个函数来做这件事。两个关键组成部分是 (1) 使用 retain_graph=True 来处理除对 .backward() 的最终调用之外的所有内容,以及 (2) 在每次调用 .backward() 后保存毕业生,并在最后在 .step()ing 之前恢复它们。

    def multi_step(losses, optms):
      # optimizers each take a step, with `optms[i]`'s variables being 
      # optimized w.r.t. `losses[i]`.
      grads = [None]*len(losses)
      for i, (loss, optm) in enumerate(zip(losses, optms)):
        retain_graph = i != (len(losses)-1)
        optm.zero_grad()
        loss.backward(retain_graph=retain_graph)
        grads[i] = [ 
              [ 
                p.grad+0 for p in group['params'] 
              ] for group in optm.param_groups
            ]
      for optm, grad in zip(optms, grads):
        for p_group, g_group in zip(optm.param_groups, grad):
          for p, g in zip(p_group['params'], g_group):
            p.grad = g
        optm.step()
    

    在问题中所述的示例代码中,multi_step 将按如下方式使用:

    for i in range(n_iter):
      shared_computation = foobar(x, y, z)
    
      x_loss = f(x, y, z, shared_computation)
      y_loss = g(x, y, z, shared_computation)
      z_loss = h(x, y, z, shared_computation)
    
      multi_step([x_loss, y_loss, z_loss], [x_opt, y_opt, z_opt])
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-05-28
      • 2020-09-10
      • 2019-09-07
      • 1970-01-01
      • 1970-01-01
      • 2022-01-18
      • 2018-07-12
      • 2021-02-20
      相关资源
      最近更新 更多