【问题标题】:Replace torch.gather by other operator?由其他操作员替换 torch.gather?
【发布时间】:2022-01-19 17:50:40
【问题描述】:

我有一个脚本代码,其中x1x2 大小为1x68x8x8

 tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
 x1 = x1.view(tmp_batch*tmp_channel, -1)        
 max_ids = torch.argmax(x1, 1)            
 max_ids = max_ids.view(-1, 1)
            
 x2 = x2.view(tmp_batch*tmp_channel, -1)
 outputs_x_select = torch.gather(x2, 1, max_ids) # size of 68 x 1

至于上面的代码,当我使用旧的onnx 时,我遇到了torch.gather 的问题。因此,我想找到一个替代解决方案,将toch.gather 替换为其他运算符,但输出与上述代码相同。你能给我一些建议吗?

【问题讨论】:

标签: python pytorch


【解决方案1】:

一种解决方法是使用等效的 numpy 方法。如果您在某处包含 import numpy as np 语句,则可以执行以下操作。

outputs_x_select = torch.Tensor(np.take_along_axis(x2,max_ids,1))

如果这给您带来了与毕业相关的错误,请尝试

outputs_x_select = torch.Tensor(np.take_along_axis(x2.detach(),max_ids,1))

没有 numpy 的方法:在这种情况下,max_ids 似乎每行只包含一个条目。因此,我相信以下方法会起作用:

max_ids = torch.argmax(x1, 1) # do not reshape
            
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = x2[torch.arange(tmp_batch*tmp_channel),max_ids]

【讨论】:

  • 但是我需要torch的操作员
  • @KimHee 我有另一个想法,看看我的最新编辑
  • 很好的答案,我想我会遵循 torch.arrange 的建议。但是,如果我有 max_ids_nb = max_ids.repeat(1, num_nb).view(-1, 1) outputs_nb_x = outputs_nb_x.view(tmp_batch*num_nb*tmp_channel, -1) outputs_nb_x_select = torch.gather(outputs_nb_x, 1, max_ids_nb) ,那么如何使用 torch.arange 作为您的建议
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2023-02-04
  • 2012-01-26
  • 1970-01-01
  • 2022-06-16
  • 2018-11-21
  • 2019-02-18
  • 1970-01-01
相关资源
最近更新 更多