【发布时间】:2021-08-09 20:18:38
【问题描述】:
假设我有标签可以将数据分成几组。现在我想找到每个集群的最大 w.r.t 原始数组的索引。例如:
import numpy as np
labels = np.array([1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0])
y = np.random.randint(0, 9, len(labels)) #array([6, 7, 5, 4, 2, 8, 4, 4, 5, 6, 4])
我想要 [1, 5] 因为对于集群 1,索引 1 处的最大值为 7,而对于集群 0,索引 5 处的最大值为 8。是否有可能在没有 for 循环的情况下得到它?
我的幼稚方案供参考:
out = []
for i in [0, 1]:
temp = y.copy()
temp[labels == i] = -1
out.append(np.argmax(temp))
【问题讨论】:
-
在这种情况下 8 不应该在索引 2 处吗?
-
如果我们只考虑分组数组,则为 2,但相对于原始数组为 5
-
我相信这个答案会奏效,但它对我来说仍然有点难看。明天有时间我会努力改进的。