【问题标题】:Most efficient way to operate on a n-dim array based on a reference n-dim array基于参考 n-dim 数组对 n-dim 数组进行操作的最有效方法
【发布时间】:2019-10-29 14:07:05
【问题描述】:

我有两个相同形状的 numpy 数组:dat_araref_ara

我想对axis = -1dat_ara 执行操作op_func,但是我只想对每个数组中的选定值切片进行操作,当阈值thres 为时指定切片被引用数组ref_ara交叉。

为了说明,在数组只是 2-dim 的简单情况下,我有:

thres = 4

op_func = np.average

ref_ara = array([[1, 2, 1, 4, 3, 5, 1, 5, 2, 5],
                 [1, 2, 2, 1, 1, 1, 2, 7, 5, 8],
                 [2, 3, 2, 5, 1, 6, 5, 2, 7, 3]]) 

dat_ara = array([[1, 0, 0, 1, 1, 1, 1, 0, 1, 1],
                 [1, 1, 1, 1, 1, 1, 1, 0, 1, 0],
                 [1, 0, 1, 1, 1, 1, 0, 1, 1, 1]]) 

我们看到thresref_araaxis=0 的第一个、第二个和第三个数组的第5 个、第7 个和第3 个索引中被破坏。因此我想要的结果是

out_ara = array([op_func(array([1, 0, 0, 1, 1, 1]), 
                 op_func(array([1, 1, 1, 1, 1, 1, 1, 0]),
                 op_func(array([1, 0, 1, 1])])

这个问题很困难,因为它需要引用ref_ara。如果不是这样,我可以简单地使用numpy.apply_along_axis

我已经尝试扩展两个数组的维度以将它们关联起来进行计算,即:

assos_ara = np.append(np.expand_dims(dat_ara, axis=-1), np.expand_dims(ref_ara, axis=-1), axis=-1)

但同样,numpy.apply_along_axis 要求输入函数只能对 1-dim 数组进行操作,因此我仍然无法使用该函数。

我知道的唯一另一种方法是按索引迭代数组,但是,由于数组的两个数组的维数不断变化,这是一个棘手的问题,而且计算效率不高。

我很想使用矢量化函数来帮助这个过程。最有效的方法是什么?

【问题讨论】:

  • 今天是你的幸运日。我刚刚问过这个:stackoverflow.com/q/58595650/2988730
  • 不错!这个问题确实看起来很相似,需要一些时间来消化代码..
  • 这是您需要的构建块之一。我正在给你写一个答案。

标签: python arrays numpy vectorization


【解决方案1】:

这是屏蔽数组的一个很好的用例,因为它们允许您对部分数据执行正常的 numpy 操作。

假设每一行至少包含一个大于阈值的值。您可以将断点的索引计算为

breaks = np.argmax(ref_ara > thres, axis=-1)   # 5, 7, 3

然后您可以使用answer 到我之前链接的question 创建一个掩码。掩码通常是处理 numpy 中不规则形状数据的最佳方式。

mask = np.arange(ref_ara.shape[-1]) <= breaks.reshape(*breaks.shape, 1)

在这里,我们不需要对arange 做任何花哨的事情,因为它位于最后一个维度。如果不是这种情况,您可能希望将 1 插入到范围所在的中断形状中,并用 1 填充范围形状的尾部。

现在掩码数组和 ufunc 解决方案略有不同。掩码数组版本更通用,所以先来:

data = np.ma.array(data_ara, mask=~mask)

掩码数组从普通布尔索引的方式向后解释掩码,因此我们反转掩码。或者,您可以使用&gt; 而不是&lt;= 来计算掩码。计算现在很简单:

out_ara = np.ma.average(data, axis=-1).data

一个不太通用的替代方法是将您的操作分解为 ufunc,并使用它们提供的掩码。这对np.average 来说很容易,只有np.sumnp.divide,但对于更复杂的操作可能会更难。

从 numpy 1.17.0 开始,np.sum 有一个 where 关键字:

out_ara = np.sum(dat_ara, where=mask, axis=-1) / breaks

【讨论】:

  • 真棒答案,为了更简洁,我认为breaks.reshape(*breaks.shape, 1)可以替换为breaks[..., None]
  • 澄清一下,有必要使用np.ma.average吗?还是np.average 和任何其他numpy ufunc 也能正常工作?
  • @Tian。你可以同时尝试,但我认为你需要 np.ma.average 来尊重面具。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2022-12-11
  • 1970-01-01
  • 2021-10-08
  • 2017-03-24
  • 1970-01-01
  • 1970-01-01
  • 2013-07-07
相关资源
最近更新 更多