【问题标题】:what is the difference between if-else statement and torch.where in pytorch?pytorch中if-else语句和torch.where有什么区别?
【发布时间】:2020-04-13 08:55:49
【问题描述】:

见代码sn-p:

import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)

输出是tensor([0.]),但是

import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
    y = x
else:
    y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)

输出是None

我很困惑为什么torch.where 的输出是tensor([0.])

更新

import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b

(a[0, 0] * a[0, 1]).backward()
print(b.grad)

输出为tensor([2., 0.])(a[0, 0] * a[0, 1])b[1] 没有任何关系,但b[1] 的梯度是0 而不是None

【问题讨论】:

    标签: python pytorch automatic-differentiation


    【解决方案1】:

    基于跟踪的 AD,如 pytorch,由 tracking 工作。您无法跟踪不是库拦截的函数调用的内容。通过使用这样的if 语句,xy 之间没有任何联系,而wherexy 在表达式树中被链接。

    现在,对于差异:

    • 在第一个 sn-p 中,0 是函数 x ↦ x > 0 ? x : 2 在点 -1 处的正确导数(因为负侧是常数)。
    • 在第二个 sn-p 中,正如我所说,xy 没有任何关系(在else 分支中)。因此,y给定x的导数是未定义的,表示为None

    (你甚至可以在 Python 中做这样的事情,但这需要更复杂的技术,比如源代码转换。我不认为 pytorch 可以做到。)

    【讨论】:

    • 我猜0 渐变等价于None,请参阅我的更新。
    • 不,还是一样的原理。在您的新示例中, (a[0, 0] * a[0, 1]) 被认为是整个 b 的函数。您仅通过与 b[1] 相关的恒定部分进行反向传播。建议大家熟悉一下一般的AD系统的实现细节,这样会更容易看懂。
    • 有没有关于AD系统实现细节的教程,尤其是就地操作,存在很多怪现象。
    • 我有一组参考资料 here,但它非常关注 Julia。可变性很少被处理,因为它使事情变得困难。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-02-06
    • 2020-04-04
    • 1970-01-01
    • 1970-01-01
    • 2014-07-22
    • 2019-05-06
    相关资源
    最近更新 更多