【发布时间】:2022-01-19 17:50:40
【问题描述】:
我有一个脚本代码,其中x1 和x2 大小为1x68x8x8
tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
x1 = x1.view(tmp_batch*tmp_channel, -1)
max_ids = torch.argmax(x1, 1)
max_ids = max_ids.view(-1, 1)
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = torch.gather(x2, 1, max_ids) # size of 68 x 1
至于上面的代码,当我使用旧的onnx 时,我遇到了torch.gather 的问题。因此,我想找到一个替代解决方案,将toch.gather 替换为其他运算符,但输出与上述代码相同。你能给我一些建议吗?
【问题讨论】:
-
另见this post。