【问题标题】:pyTorch can backward twice without setting retain_graph=TruepyTorch 可以在不设置 retain_graph=True 的情况下向后退两次
【发布时间】:2019-02-27 00:39:07
【问题描述】:

pyTorch tutorial所示,

如果您甚至想在图表的某些部分进行两次倒退, 你需要在第一次传递时传入retain_graph = True。

但是,我发现以下代码 sn-p 实际上没有这样做。我正在使用 pyTorch-0.4

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
y.backward(torch.ones(2, 2)) # Note I do not set retain_graph=True
y.backward(torch.ones(2, 2)) # But it can still work!
print x.grad

输出:

tensor([[ 2.,  2.], 
        [ 2.,  2.]]) 

谁能解释一下?提前致谢!

【问题讨论】:

    标签: pytorch autograd


    【解决方案1】:

    在您的情况下它可以在没有 retain_graph=True 的情况下工作的原因是您有一个非常简单的图形,可能没有内部中间缓冲区,反过来不会释放任何缓冲区,因此无需使用 retain_graph=True

    但是当向您的图表添加额外的计算时,一切都会发生变化:

    代码:

    x = torch.ones(2, 2, requires_grad=True)
    v = x.pow(3)
    y = v + 2
    
    y.backward(torch.ones(2, 2))
    
    print('Backward 1st time w/o retain')
    print('x.grad:', x.grad)
    
    print('Backward 2nd time w/o retain')
    
    try:
        y.backward(torch.ones(2, 2))
    except RuntimeError as err:
        print(err)
    
    print('x.grad:', x.grad)
    

    输出:

    Backward 1st time w/o retain
    x.grad: tensor([[3., 3.],
                    [3., 3.]])
    Backward 2nd time w/o retain
    Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
    x.grad: tensor([[3., 3.],
                    [3., 3.]]).
    

    在这种情况下,额外的内部v.grad 将被计算,但torch 不存储中间值(中间梯度等),并且retain_graph=False v.grad 将在第一个backward 之后被释放。

    所以,如果你想第二次反向传播,你需要指定retain_graph=True 来“保留”图表。

    代码:

    x = torch.ones(2, 2, requires_grad=True)
    v = x.pow(3)
    y = v + 2
    
    y.backward(torch.ones(2, 2), retain_graph=True)
    
    print('Backward 1st time w/ retain')
    print('x.grad:', x.grad)
    
    print('Backward 2nd time w/ retain')
    
    try:
        y.backward(torch.ones(2, 2))
    except RuntimeError as err:
        print(err)
    print('x.grad:', x.grad)
    

    输出:

    Backward 1st time w/ retain
    x.grad: tensor([[3., 3.],
                    [3., 3.]])
    Backward 2nd time w/ retain
    x.grad: tensor([[6., 6.],
                    [6., 6.]])
    

    【讨论】:

      猜你喜欢
      • 2020-06-15
      • 1970-01-01
      • 2020-10-06
      • 1970-01-01
      • 1970-01-01
      • 2020-09-19
      • 2016-12-16
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多