【问题标题】:How to get index of multiple, possibly different, elements in numpy?如何在numpy中获取多个可能不同的元素的索引?
【发布时间】:2021-11-24 12:38:25
【问题描述】:

我有一个 numpy 数组,其中包含许多行,大致如下:

0, 50, 50, 2, 50, 1, 50, 99, 50, 50
50, 2, 1, 50, 50, 50, 98, 50, 50, 50
0, 50, 50, 98, 50, 1, 50, 50, 50, 50
0, 50, 50, 50, 50, 99, 50, 50, 2, 50
2, 50, 50, 0, 98, 1, 50, 50, 50, 50

给我一​​个变量n

  • 从 0 到 n 的每个数字,其中一个可能缺失。在上面的示例中,n=2。
  • 可能是 98,如果缺少数字,它将代替缺少的数字。
  • 可能是 99,如果缺少数字并且还没有 98,它将代替缺少的数字。
  • 许多 50 年代。

我想要的是一个数组,其中包含第一行中所有 0 的索引、第二行中所有 1 的索引、第三行中所有 2 的索引等。例如,我想要的输出是这样的:

0, 6, 0, 0, 3
5, 2, 5, 5, 5
3, 1, 3, 8, 0

您可能已经注意到一个问题:有时,恰好其中一个数字被替换为 98 或 99。编写一个 for 循环来确定哪个数字(如果有)被替换并使用它非常容易获取索引数组。

有没有办法用 numpy 做到这一点?

【问题讨论】:

  • 我读了 3 遍我不知道你在做什么
  • @barker,基本上我试图获取每行中 0 的索引,然后是每行中 1 的索引,等等。问题是,可能没有 1 in某行,在这种情况下,我需要查看 98/99 的索引。
  • 这与您想要的输出不匹配,第一个索引行 0,6,0,0,3 不符合该逻辑
  • 为什么不呢?在第一行中,0 在索引 0 处。在第二行中,0 在索引 6 处。在第三行中,在索引 0 处。在第四行中,在索引 0 处。在第五行中,它是索引 3.

标签: python arrays numpy multidimensional-array indexing


【解决方案1】:

我不认为你在这里没有一个 for 循环就可以逃脱。但是,您可以这样做。

对于n 中的每个数字,找到它已知的所有位置。示例:

locations = np.argwhere(data == 1)
print(locations)
[[0 5]
 [1 2]
 [2 5]
 [4 5]]

然后您可以将其转换为地图,以便在n 中轻松查找每个号码:

known = {
    i: dict(np.argwhere(data == i))
    for i in range(n + 1)
}
pprint(known)
{0: {0: 0, 2: 0, 3: 0, 4: 3},
 1: {0: 5, 1: 2, 2: 5, 4: 5},
 2: {0: 3, 1: 1, 3: 8, 4: 0}}

对未知数做同样的事情:

unknown = dict(np.argwhere((data == 98) | (data == 99)))
pprint(unknown)
{0: 7, 1: 6, 2: 3, 3: 5, 4: 4}

现在对于结果中的每个位置,您都可以在已知列表中查找索引并回退到未知位置。

result = np.array(
    [
        [known[i].get(j, unknown.get(j)) for j in range(len(data))]
        for i in range(n + 1)
    ]
)
print(result)
[[0 6 0 0 3]
 [5 2 5 5 5]
 [3 1 3 8 0]]

奖励:使用字典构造函数和解包:

from collections import OrderedDict

unknown = np.argwhere((data == 98) | (data == 99))
results = np.array([
    [*OrderedDict((*unknown, *np.argwhere(data == i))).values()]
    for i in range(n + 1)
])
print(results)

【讨论】:

  • 非常感谢!这似乎是更可扩展的答案,以防万一发现此问题的人有类似问题。
【解决方案2】:

下面的numpy 解决方案相当积极地使用了 OP 中列出的假设。如果不能 100% 保证,则可能需要进行更多检查。

这里有点聪明(即使我自己这么说)是使用数据数组本身来查找索引的正确目的地。例如,所有 2 都需要将它们的索引存储在输出数组的第 2 行中。使用它,我们可以在单个操作中批量存储大部分索引。

示例输入在数组data

n = 2
y,x = data.shape
out = np.empty((y,n+1),int)
# find 98 falling back to 99 if necessary
# and fill output array with their indices
# if neither exists some nonsense will be written but that does no harm
# most of this will be overwritten later
out.T[...] = ((data-98)&127).argmin(axis=1)
# find n+1 lowest values in each row
idx = data.argpartition(n,axis=1)[:,:n+1]
# construct auxiliary indexer
yr = np.arange(y)[:,None]
# put indices of low values where they belong
out[yr,data[yr,idx[:,:-1]]] = idx[:,:-1]
#      ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ 
#         the clever bit
# rows with no missing number still need the last value
nomiss, = (data[range(y),idx[:,n]] == n).nonzero()
out[nomiss,n] = idx[nomiss,n]
# admire
print(out.T)

输出:

[[0 6 0 0 3]
 [5 2 5 5 5]
 [3 1 3 8 0]]

【讨论】:

  • 哇,这是魔法。非常感谢——这一定是我见过的最出色的 numpy 代码。
猜你喜欢
  • 2018-11-21
  • 2015-11-18
  • 2023-03-14
  • 1970-01-01
  • 1970-01-01
  • 2021-12-01
  • 2019-10-09
  • 2013-06-26
  • 2011-07-25
相关资源
最近更新 更多