【问题标题】:How can I find the n indices of minimum elements for each row using numpy?如何使用 numpy 找到每行的最小元素的 n 个索引?
【发布时间】:2021-10-30 17:59:14
【问题描述】:

例如:

n = 2
p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])

结果,我想要:

[ [0, 0, 2],
  [1, 0, 1],
  [2, 1, 2] ]
            

每行的第一个数字就是 p1 中的行号。剩余的 n 个数字是行的最小元素的索引。所以:

[0, 0, 2]
 # 0 is the index of the first row in p1.
 # (0, 2 are the indices of minimum elements of the row)


[1, 0, 1]
# 1 is the index of the second row in p1.
# (0, 1 are the indices of minimum elements of the row)

[2, 1, 2]
# 2 is the index of the third row in p1.
# (1, 2 are the indices of minimum elements of the row)

非常感谢!!!

【问题讨论】:

  • minimum elements 是什么?

标签: python numpy


【解决方案1】:

使用np.argpartition 找出前两个最小值:

import numpy as np

n = 2
p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])

pos = np.argpartition(p1, axis=1, kth=2)

res = np.hstack([np.arange(3)[:, None], np.sort(pos[:, :2])])
print(res)

输出

[[0 0 2]
 [1 0 1]
 [2 1 2]]

找到最小值后,使用np.hstack 将行的索引连接起来。

【讨论】:

    猜你喜欢
    • 2017-12-13
    • 2012-12-17
    • 2019-05-31
    • 1970-01-01
    • 2019-01-21
    • 2013-07-30
    • 1970-01-01
    • 2018-09-01
    相关资源
    最近更新 更多