【问题标题】:PyTorch gradient calculation gives unexpected resultsPyTorch 梯度计算给出了意想不到的结果
【发布时间】:2020-07-19 17:17:33
【问题描述】:

以下代码旨在解释 PyTorch 梯度计算的工作原理,IMO 应该返回权重矩阵,但它没有:

# the code calculates T x W + B ---> K1
# compute the mean of K1 --> km
# compute the gradient of km relative to T
#
import torch
torch.manual_seed(0)
t = torch.rand(2,3)
w = torch.rand(3,4)
b = torch.rand(1,4)
#
k1 = torch.mm(t, w) + b
#
#torch.set_grad_enabled(True)
print('k1_grad_fn ',k1.grad_fn)
#
print('t grad ',t.grad)
#
#
km = k1.mean()
km.requires_grad_(True)
print('k1 mean=',km)
km.backward()
print('t grad ',t.grad)
print('k1 grad ',k1.grad_fn)

结果是:

t grad  None
k1 mean= tensor(1.0396, requires_grad=True)
t grad  None
k1 grad  None```

【问题讨论】:

    标签: python pytorch gradient matrix-multiplication autograd


    【解决方案1】:

    不奇怪,应该是indicated by Autograd Mechanics in PyTorch documentation

    在子图中永远不会执行向后计算,其中所有 张量不需要梯度。

    设置km.required_grad_(True) 会创建一个子图(只有一个操作是k1 张量的平均值)所以你得到k1_mean.grad。请注意k1 grad 在这种情况下也将是None,因为它不需要渐变。

    默认情况下tensorrequired_grad 设置为False,因此图形不必执行不必要的操作。要获得您想要的行为,请将您的 t 更改为:

    t = torch.rand((2, 3), requires_grad=True)
    

    (每个张量函数都有requires_grad 参数)。正如您可能预期的那样,这给了您:

    k1_grad_fn  <AddBackward0 object at 0x7fe3454d6100>
    t grad  None
    k1 mean= tensor(1.0396, grad_fn=<MeanBackward0>)
    t grad  tensor([[0.3093, 0.1177, 0.2888],
            [0.3093, 0.1177, 0.2888]])
    k1 grad  <AddBackward0 object at 0x7fe3454d6100>
    

    【讨论】:

    • 非常感谢。尽管如此,输出并不是我期望看到的。对于正确的数学解决方案,我将不胜感激。
    • 数学上是正确的,你是什么意思? tensor.mean() 按原样返回标量,k1 grad_fn 也是 Addbackward,正如预期的那样,k1 没有 grad 答案中给出,wb 需要 require_grad参数就像t,但那些不包含在输出中。如果我仍然没有抓住重点,请更具体地提出您的问题,或者具体说明您所关注的部分,但解释得不够深入。
    • 我的期望(可能是错误的)是在输出中看到权重矩阵,因为我正在做 T x W 对 t 的导数。输出是一个 2x4 矩阵,我想手动计算以进行验证
    • 只有当您像上面那样指定requires_grad 时,您才会针对t 执行此操作。 .grad 是一个矩阵,正如您在答案中看到的那样,它的形状与您最初创建的 2x3 相同。 PyTorch 怎么知道w.r.t. 你正在执行什么衍生?如果 PyTorch 将 每个操作的每个梯度 保留为 grad 属性,那将是一种浪费,因为它通常不需要并且只占用内存。
    • 再次感谢。我终于想通了。 documentcloud.adobe.com/link/…
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2019-05-19
    • 2017-01-05
    • 2021-10-14
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-12-17
    相关资源
    最近更新 更多