【发布时间】:2020-05-22 14:08:56
【问题描述】:
假设,我有一个 3D 张量 A
A = torch.arange(24).view(4, 3, 2)
print(A)
并要求使用 2D 张量对其进行屏蔽
mask = torch.zeros((4, 3), dtype=torch.int64) # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
使用 PyTorch 中的 masked_select 功能会导致以下错误。
torch.masked_select(X, (mask == 1))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
12
13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
15 #Y = X * mask_
16 print(Y)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2
如何用 2D 蒙版屏蔽 3D 张量并保持原始向量的尺寸?任何提示将不胜感激。
【问题讨论】: