【发布时间】:2021-12-19 00:38:26
【问题描述】:
我正在尝试找到一种高效的 numpy 循环(与 python 循环相反)方法来获取每组成本最低的数据点的索引。类似于np.minimum.at 所做的事情,但使用“argminimum”而不是最小值。 (并且np.argmin.at 不存在)。
以下演示了我正在寻找的内容:
names, groups, costs = zip(*[
('a', 0, 2.0), # no (d is lower cost)
('b', 1, 3.), # yes (tied but first)
('c', 2, 3.), # yes (only one)
('d', 0, 1.2), # yes
('e', 3, 3.), # no (k is lower)
('f', 4, 3.), # no (j is lower)
('g', 5, 3.), # yes
('h', 1, 3.), # no (tied but not first)
('i', 0, 4.), # no (d is lower)
('j', 4, 2.3), # yes
('k', 3, 0.6), # yes
('l', 5, 7.), # no (g is lower)
])
mask = get_minimal_unique_index_mask(arr=np.array(groups), values=np.array(costs))
selected = ''.join(c for c, m in zip(names, mask) if m)
expected = 'bcdgjk'
assert selected == expected, f"Selected: '{selected}'. Expected: '{expected}'"
我正在尝试找到get_minimal_unique_index_mask 的有效实现。我知道我可以使用 dicts 和 python 循环轻松做到这一点:
def get_minimal_unique_index_mask(groups: Array['N', Any], values: Array['N', float]) -> Array['N', bool]:
min_ixs_vals = {}
for i, (group, val) in enumerate(zip(groups, values)):
if group not in min_ixs_vals:
min_ixs_vals[group] = i
else:
min_ixs_vals[group] = i if val < values[min_ixs_vals[group]] else min_ixs_vals[group]
argmin_per_group_mask = np.zeros(len(groups), dtype=bool)
argmin_per_group_mask[list(min_ixs_vals.values())] = True
return argmin_per_group_mask
...上述方法有效,但在 python 中循环,因此会很慢。我想知道是否有一个聪明的 numpy 方法来做同样的事情。
【问题讨论】:
-
我可能会为此使用 pandas 的 groupby 功能。
-
谢谢 Quang,我试过了,现在似乎有一个工作功能。