【发布时间】:2021-07-30 14:16:51
【问题描述】:
以下是我的玩具示例
import torch
x = torch.tensor(3.0, requires_grad = True)
y = x**2
y.backward(retain_graph = True)
print(x.grad)
x = x + 4
y.backward(retain_graph = True)
print(x.grad)
第一个打印打印 x 的渐变,而第二个打印什么也不打印。为什么 x 被 x = x + 4 更新后 x 的梯度消失了?谢谢。
新增问题:
下面的代码可以做我想做的事,它迭代地更新 x。但是,每次更新时我都需要添加 x.requires_grad = True 。不使用 x.requires_grad = True 有没有更好的方法?谢谢。
x = torch.tensor(3.0, requires_grad = True)
y = x**2
y.backward(retain_graph = True)
with torch.no_grad():
x = x + x.grad
x.requires_grad = True
y = x**2
y.backward(retain_graph = True)
print(x.grad)
更新:我的解决方案
x = torch.tensor(3.0, requires_grad = True)
y = x**2
y.backward(retain_graph = True)
print(x.grad)
x.data = x.data + x.grad.data
x.grad.zero_()
y = x**2
y.backward(retain_graph = True)
print(x.grad)
代码的结果是
tensor(6.)
tensor(18.)
,这正是我想要的。谢谢。
【问题讨论】: