【问题标题】:How to convert a list of Torch tensor with grad to tensor如何将带有 grad 的 Torch 张量列表转换为张量
【发布时间】:2021-01-14 03:48:08
【问题描述】:

我有一个名为 pts 的变量,其形状为 [batch, ch, h, w]。这是一个热图,我想将其转换为第二个坐标。目标是,pts_o = heatmap_to_pts(pts),其中 pts_o 将是 [batch, ch, 2]。到目前为止,我已经编写了这个函数,

def heatmap_to_pts(self, pts):  <- pts [batch, 68, 128, 128]
    
    pt_num = []
    
    for i in range(len(pts)):
        
        pt = pts[i]
        if type(pt) == torch.Tensor:

            d = torch.tensor(128)                                                   * get the   
            m = pt.view(68, -1).argmax(1)                                           * indices
            indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1)  * from heatmaps
        
            pt_num.append(indices.type(torch.DoubleTensor) )   <- store the indices in a list

    b = torch.Tensor(68, 2)                   * trying to convert
    c = torch.cat(pt_num, out=b) *error*      * a list of tensors with grad
    c = c.reshape(68,2)                       * to a tensor like [batch, 68, 2]

    return c

错误提示“cat(): 带有 out=... 参数的函数不支持自动微分,但其中一个参数需要 grad。”。它无法进行操作,因为 pt_num 中的张量需要 grad"。

如何将该列表转换为张量?

【问题讨论】:

    标签: computer-vision pytorch heatmap tensor face-alignment


    【解决方案1】:

    错误说,

    cat():带有 out=... 参数的函数不支持自动微分,但其中一个参数需要 grad。

    这意味着像 torch.cat() 这样的函数的输出,作为 out= kwarg 不能用作 autograd 引擎(执行自动微分)的输入。

    原因是张量(在您的 Python 列表 pt_num 中)具有不同的 requires_grad 属性值,即,一些张量具有 requires_grad=True,而其中一些具有 requires_grad=False

    在您的代码中,以下行(逻辑上)很麻烦:

    c = torch.cat(pt_num, out=b) 
    

    torch.cat() 的返回值,无论您是否使用out= kwarg,都是沿上述维度串联的张量。

    所以,张量 c 已经是 pt_num 中各个张量的串联版本。使用out=b 冗余。因此,您可以简单地摆脱out=b,一切都会好起来的。

    c = torch.cat(pt_num)
    

    【讨论】:

      猜你喜欢
      • 2020-12-12
      • 1970-01-01
      • 2021-11-10
      • 2020-02-11
      • 2017-07-21
      • 2020-08-05
      • 1970-01-01
      • 2021-02-02
      • 2014-10-04
      相关资源
      最近更新 更多