【发布时间】:2019-09-04 12:28:52
【问题描述】:
我有一个尺寸为[NxQxD] 的张量M 和一个指数为idx 的一维张量(尺寸为N)。我想有效地创建一个尺寸为[NxQxD] 的张量mask,使得mask[i,j,k] = 1 iff j <= idx[i],即我只想在@987654330 的第二维(dim=1)中保留Q 中的idx[i] 第一维@,对于每一行 i。
谢谢!
【问题讨论】:
我有一个尺寸为[NxQxD] 的张量M 和一个指数为idx 的一维张量(尺寸为N)。我想有效地创建一个尺寸为[NxQxD] 的张量mask,使得mask[i,j,k] = 1 iff j <= idx[i],即我只想在@987654330 的第二维(dim=1)中保留Q 中的idx[i] 第一维@,对于每一行 i。
谢谢!
【问题讨论】:
事实证明,这可以通过广播技巧来完成:
mask_2d = torch.arange(Q)[None, :] < idx[:, None] #(N,Q)
mask_3d = mask[..., None] #(N,Q,1)
masked = mask.float() * data
【讨论】: