【发布时间】:2021-11-12 11:32:28
【问题描述】:
我有以下张量,我们称之为 lookup_table:
tensor([266, 103, 84, 12, 32, 34, 1, 523, 22, 136, 268, 432, 53, 63,
201, 51, 164, 69, 31, 42, 122, 131, 119, 36, 245, 60, 28, 81,
9, 114, 105, 3, 41, 86, 150, 79, 104, 120, 74, 420, 39, 427,
40, 59, 24, 126, 202, 222, 145, 429, 43, 30, 38, 55, 10, 141,
85, 121, 203, 240, 96, 7, 64, 89, 127, 236, 117, 99, 54, 90,
57, 11, 21, 62, 82, 25, 267, 75, 111, 518, 76, 56, 20, 2,
61, 516, 80, 78, 555, 246, 133, 497, 33, 421, 58, 107, 92, 68,
13, 113, 235, 875, 35, 98, 102, 27, 14, 15, 72, 37, 16, 50,
517, 134, 223, 163, 91, 44, 17, 412, 18, 48, 23, 4, 29, 77,
6, 110, 67, 45, 161, 254, 112, 8, 106, 19, 498, 101, 5, 157,
83, 350, 154, 238, 115, 26, 142, 143])
我还有另一个张量,我们称之为 data,如下所示:
tensor([[517, 235, 236, 76, 81, 25, 110, 59, 245, 39],
[523, 114, 350, 246, 30, 222, 39, 517, 106, 2],
[ 35, 235, 120, 99, 266, 63, 236, 133, 412, 38],
[134, 2, 497, 21, 78, 60, 142, 498, 24, 89],
[ 60, 111, 120, 145, 91, 141, 164, 81, 350, 55]])
现在我想要一些看起来像这样的东西:
tensor([112, 100, ..., 40],
[7, 29, ..., 2],
..., ])
我想使用我的数据张量来获取查找表的索引。
基本上我想把这个矢量化:
(lookup_table == data).nonzero()
所以这适用于多维数组。
我已阅读此内容,但它们不适用于我的情况:
How Pytorch Tensor get the index of specific value
How Pytorch Tensor get the index of elements?
Pytorch tensor - How to get the indexes by a specific tensor
编辑:
我基本上是在寻找这个的优化/矢量化版本:
x_data = torch.stack([(lookuptable == data[0][i]).nonzero(as_tuple=False) for i in range(len(data[0]))]).flatten().unsqueeze(0)
print(x_data.size())
for o in range(1, len(data)):
x_data = torch.cat((x_data, torch.stack([(lookuptable == data[o][i]).nonzero(as_tuple=False) for i in range(len(data[o]))]).flatten().unsqueeze(0)), dim=0)
EDIT 2 最小示例:
我们有 data 张量:
data = torch.Tensor([
[523, 114, 350, 246, 30, 222, 39, 517, 106, 2],
[ 35, 235, 120, 99, 266, 63, 236, 133, 412, 38],
[555, 104, 14, 81, 55, 497, 222, 64, 57, 131]
])
我们有 lookup_table 张量,见上文。
如果我们将此代码应用于 2 个张量:
# convert champion keys into index notation
x_data = torch.stack([(lookuptable == x[0][i]).nonzero(as_tuple=False) for i in range(len(x[0]))]).flatten().unsqueeze(0)
for o in range(1, len(data) - 1):
x_data = torch.cat((x_data, torch.stack([(lookuptable == x[o][i]).nonzero(as_tuple=False) for i in range(len(x[o]))]).flatten().unsqueeze(0)), dim=0)
我们得到这样的输出:
tensor([[ 7, 29, 141, 89, 51, 47, 40, 112, 134, 83],
[102, 100, 37, 67, 0, 13, 65, 90, 119, 52],
[ 88, 36, 106, 27, 53, 91, 47, 62, 70, 21]
])
这个输出是我想要的,就像我在上面所说的,它是张量数据的每个值在张量查找表中的位置的索引。 问题是这不是矢量化的。 而且我不知道如何对其进行矢量化。
【问题讨论】:
-
尝试为
lookup_table和data提供最少的示例,并给出您想要获得的确切输出。
标签: python multidimensional-array pytorch tensor