【发布时间】:2021-05-25 23:14:45
【问题描述】:
我有一个张量,例如: 输入:
a = torch.rand(2,3,5)
输出:
tensor([[[0.2764, 0.2209, 0.8007, 0.1246, 0.4302],
[0.9716, 0.8063, 0.3904, 0.7574, 0.2392],
[0.3366, 0.4209, 0.0527, 0.1328, 0.0441]],
[[0.8166, 0.6519, 0.5450, 0.3072, 0.2716],
[0.0583, 0.0613, 0.8984, 0.0110, 0.4744],
[0.2269, 0.2693, 0.6447, 0.6078, 0.6148]]])
我怎样才能得到:
tensor([[[0.2764, 0.2209, 0.8007, 0.1246, 0.4302],
[0.9716, 0.8063, 0.3904, 0.7574, 0.2392]],
[[0.8166, 0.6519, 0.5450, 0.3072, 0.2716],
[0.2269, 0.2693, 0.6447, 0.6078, 0.6148]]])
我试过gather(),但是没用。
【问题讨论】:
-
你能澄清输出的预期形状吗?为什么你认为
gather()是你想要的功能?