【发布时间】:2020-04-13 08:55:49
【问题描述】:
见代码sn-p:
import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)
输出是tensor([0.]),但是
import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
y = x
else:
y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)
输出是None。
我很困惑为什么torch.where 的输出是tensor([0.])?
更新
import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b
(a[0, 0] * a[0, 1]).backward()
print(b.grad)
输出为tensor([2., 0.])。 (a[0, 0] * a[0, 1]) 与b[1] 没有任何关系,但b[1] 的梯度是0 而不是None。
【问题讨论】:
标签: python pytorch automatic-differentiation