【发布时间】:2019-10-25 00:11:00
【问题描述】:
给定一个数组和相同形状的掩码,我想要相同形状的掩码输出并包含 0,其中掩码为 False。
例如,
# input array
img = torch.randn(2, 2)
print(img)
# tensor([[0.4684, 0.8316],
# [0.8635, 0.4228]])
print(img.shape)
# torch.Size([2, 2])
# mask
mask = torch.BoolTensor(2, 2)
print(mask)
# tensor([[False, True],
# [ True, True]])
print(mask.shape)
# torch.Size([2, 2])
# expected masked output of shape 2x2
# tensor([[0, 0.8316],
# [0.8635, 0.4228]])
问题:屏蔽改变了输出的形状,如下所示:
#1: shape changed
img[mask]
# tensor([0.8316, 0.8635, 0.4228])
【问题讨论】: