您的功能很好,不是您出错的原因(使用PyTorch 1.6.0,如果您使用的是其他版本,请更新您的依赖项)。
下面的代码可以正常工作:
import torch
import torch.nn as nn
import torch.nn.functional as F
T = 5
K = 3
inputs = torch.tensor(
[[1, 2, 3,], [1, 2, 3,], [1, 2, 3,], [1, 2, 3,], [1, 2, 3,],],
requires_grad=True,
dtype=torch.float,
)
right_pad = T - K + 1
output = F.pad(inputs, (0, right_pad), "constant", value=0)
output = output.flatten()[:-T].reshape(T, T)
output.sum().backward()
print(inputs.grad)
请注意,我已将dtype 明确指定为torch.float,因为您不能将backprop 指定为整数。
view 和 slice 永远不会破坏反向传播,因为 gradient 连接到单个值,无论它被视为 1D 还是未压缩的 2D 或其他.这些没有就地修改。就地修改破坏梯度可能是:
output[0, 3] = 15.
此外,您的解决方案返回以下内容:
tensor([[1., 2., 3., 0., 0.],
[0., 1., 2., 3., 0.],
[0., 0., 1., 2., 3.],
[0., 0., 0., 1., 2.],
[3., 0., 0., 0., 1.]], grad_fn=<ViewBackward>)
所以你在左下角有一个3。如果这不是您所期望的,您应该在output = output.flatten()[:-T].reshape(T,T) 之后添加这一行(与1 相乘的上三角矩阵):
output *= torch.triu(torch.ones_like(output))
给出:
tensor([[1., 2., 3., 0., 0.],
[0., 1., 2., 3., 0.],
[0., 0., 1., 2., 3.],
[0., 0., 0., 1., 2.],
[0., 0., 0., 0., 1.]], grad_fn=<AsStridedBackward>)
还有inputs.grad:
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 0.],
[1., 0., 0.]])