【问题标题】:mask first k elements in a 3D tensor in PyTorch (different k for each row)在 PyTorch 中屏蔽 3D 张量中的前 k 个元素(每行的 k 不同)
【发布时间】:2019-09-04 12:28:52
【问题描述】:

我有一个尺寸为[NxQxD] 的张量M 和一个指数为idx 的一维张量(尺寸为N)。我想有效地创建一个尺寸为[NxQxD] 的张量mask,使得mask[i,j,k] = 1 iff j <= idx[i],即我只想在@987654330 的第二维(dim=1)中保留Q 中的idx[i] 第一维@,对于每一行 i

谢谢!

【问题讨论】:

    标签: python pytorch tensor


    【解决方案1】:

    事实证明,这可以通过广播技巧来完成:

    mask_2d = torch.arange(Q)[None, :] < idx[:, None] #(N,Q)
    mask_3d = mask[..., None] #(N,Q,1)
    masked = mask.float() * data
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-04-20
      • 1970-01-01
      • 2021-01-22
      • 1970-01-01
      • 2015-10-31
      • 1970-01-01
      相关资源
      最近更新 更多