【问题标题】:Pytorch Tensor - How to get the index of a tensor given a multidimensional tensorPytorch 张量 - 如何在给定多维张量的情况下获取张量的索引
【发布时间】:2021-11-12 11:32:28
【问题描述】:

我有以下张量,我们称之为 lookup_table

tensor([266, 103,  84,  12,  32,  34,   1, 523,  22, 136, 268, 432,  53,  63,
        201,  51, 164,  69,  31,  42, 122, 131, 119,  36, 245,  60,  28,  81,
          9, 114, 105,   3,  41,  86, 150,  79, 104, 120,  74, 420,  39, 427,
         40,  59,  24, 126, 202, 222, 145, 429,  43,  30,  38,  55,  10, 141,
         85, 121, 203, 240,  96,   7,  64,  89, 127, 236, 117,  99,  54,  90,
         57,  11,  21,  62,  82,  25, 267,  75, 111, 518,  76,  56,  20,   2,
         61, 516,  80,  78, 555, 246, 133, 497,  33, 421,  58, 107,  92,  68,
         13, 113, 235, 875,  35,  98, 102,  27,  14,  15,  72,  37,  16,  50,
        517, 134, 223, 163,  91,  44,  17, 412,  18,  48,  23,   4,  29,  77,
          6, 110,  67,  45, 161, 254, 112,   8, 106,  19, 498, 101,   5, 157,
         83, 350, 154, 238, 115,  26, 142, 143])

我还有另一个张量,我们称之为 data,如下所示:

tensor([[517, 235, 236,  76,  81,  25, 110,  59, 245,  39],
        [523, 114, 350, 246,  30, 222,  39, 517, 106,   2],
        [ 35, 235, 120,  99, 266,  63, 236, 133, 412,  38],
        [134,   2, 497,  21,  78,  60, 142, 498,  24,  89],
        [ 60, 111, 120, 145,  91, 141, 164,  81, 350,  55]])

现在我想要一些看起来像这样的东西:

tensor([112, 100, ..., 40],
       [7, 29, ..., 2],
       ...,          ])

我想使用我的数据张量来获取查找表的索引。
基本上我想把这个矢量化:

(lookup_table == data).nonzero()

所以这适用于多维数组。

我已阅读此内容,但它们不适用于我的情况:
How Pytorch Tensor get the index of specific value
How Pytorch Tensor get the index of elements?
Pytorch tensor - How to get the indexes by a specific tensor

编辑:
我基本上是在寻找这个的优化/矢量化版本:

x_data = torch.stack([(lookuptable == data[0][i]).nonzero(as_tuple=False) for i in range(len(data[0]))]).flatten().unsqueeze(0)
print(x_data.size())
for o in range(1, len(data)):
    x_data = torch.cat((x_data, torch.stack([(lookuptable == data[o][i]).nonzero(as_tuple=False) for i in range(len(data[o]))]).flatten().unsqueeze(0)), dim=0)

EDIT 2 最小示例:
我们有 data 张量:

data = torch.Tensor([
        [523, 114, 350, 246,  30, 222,  39, 517, 106,   2],
        [ 35, 235, 120,  99, 266,  63, 236, 133, 412,  38],
        [555, 104,  14,  81,  55, 497, 222,  64,  57, 131]
])

我们有 lookup_table 张量,见上文。

如果我们将此代码应用于 2 个张量:

 # convert champion keys into index notation
x_data = torch.stack([(lookuptable == x[0][i]).nonzero(as_tuple=False) for i in range(len(x[0]))]).flatten().unsqueeze(0)
for o in range(1, len(data) - 1):
    x_data = torch.cat((x_data, torch.stack([(lookuptable == x[o][i]).nonzero(as_tuple=False) for i in range(len(x[o]))]).flatten().unsqueeze(0)), dim=0)

我们得到这样的输出:

tensor([[  7,  29, 141,  89,  51,  47,  40, 112, 134,  83],
        [102, 100,  37,  67,   0,  13,  65,  90, 119,  52],
        [ 88,  36, 106,  27,  53,  91,  47,  62,  70,  21]
       ])

这个输出是我想要的,就像我在上面所说的,它是张量数据的每个值在张量查找表中的位置的索引。 问题是这不是矢量化的。 而且我不知道如何对其进行矢量化。

【问题讨论】:

  • 尝试为lookup_tabledata 提供最少的示例,并给出您想要获得的确切输出。

标签: python multidimensional-array pytorch tensor


【解决方案1】:

另一种更快的方法,假设所有值都有一个有限的范围,并且是int64(在这里,我还假设它们是非负的,但这个限制可以解决):

准备工作:

sorted_lookup_table, indexes = torch.sort(lookup_table)
lut = torch.zeros(size=(sorted_lookup_table[-1]+1,), dtype=torch.int64)
lut[:] = -1 # "not found"
lut[sorted_lookup_table] = indexes

数据处理:

index_into_lookup_table = lut[data]

【讨论】:

  • 当我尝试使用lut[sorted_lookup_table] = indexes 时,我得到了错误IndexError: tensors used as indices must be long, byte or bool tensors,即使我尝试使用lut[sorted_lookup_table] = indexes.long() 并且张量应该是一个长张量。
  • 我注意到另一个解决方案很遗憾对我不起作用,因为在 lookup_table 中是数字之间的跳跃,所以当我使用 sorted_lookup_table, indexes = torch.sort(lookup_table) 时,我最终会在数字之间跳跃。
  • @Lupos 原件 (lookup_table) 必须是 long
【解决方案2】:

使用searchsorted

为每个输入元素扫描整个lookup_table 数组是非常低效的。先对查找表进行排序怎么样(这只需要做一次)

sorted_lookup_table, indexes = torch.sort(lookup_table)

然后使用searchsorted

index_into_sorted = torch.searchsorted(sorted_lookup_table, data)

如果您需要原始lookup_table 的索引,可以使用

index_into_lookup_table = indexes[index_into_sorted]

【讨论】:

  • 您可能需要重命名 sorted,因为它是 Python 中的保留关键字!
  • 好的,它不是保留的,你的代码会运行,是的。但使用 builtin function 名称作为变量名被认为是不好的做法。
  • 这个解决方案效果很好,非常感谢。我从 pytorch 不知道这些方法。
猜你喜欢
  • 2019-02-05
  • 2020-09-26
  • 2019-01-13
  • 2019-09-15
  • 2021-06-27
  • 2020-05-11
  • 2021-12-30
  • 2020-03-28
  • 2019-11-26
相关资源
最近更新 更多