【问题标题】:Numpy: replace each element in a row by the maximum of other elements in the same rowNumpy:将一行中的每个元素替换为同一行中其他元素的最大值
【发布时间】:2018-01-24 08:30:55
【问题描述】:

假设我们有一个这样的二维数组:

>>> a
array([[1, 1, 2],
   [0, 2, 2],
   [2, 2, 0],
   [0, 2, 0]])

对于每一行,我想用同一行中其他 2 个元素中的最大值替换每个元素。

我已经找到了如何使用 numpy.amax 和一个标识数组分别为每一列执行此操作,如下所示:

>>> np.amax(a*(1-np.eye(3)[0]), axis=1)
array([ 2.,  2.,  2.,  2.])
>>> np.amax(a*(1-np.eye(3)[1]), axis=1)
array([ 2.,  2.,  2.,  0.])
>>> np.amax(a*(1-np.eye(3)[2]), axis=1)
array([ 1.,  2.,  2.,  2.])

但我想知道是否有办法避免 for 循环并直接获得结果,在这种情况下应该如下所示:

>>> numpy_magic(a)
array([[2, 2, 1],
   [2, 2, 2],
   [2, 2, 2],
   [2, 0, 2]])

编辑:在控制台中玩了几个小时后,我终于想出了我正在寻找的解决方案。准备好一些令人兴奋的一行代码:

np.amax(a[[range(a.shape[0])]*a.shape[1],:][(np.eye(a.shape[1]) == 0)[:,[range(a.shape[1])*a.shape[0]]].reshape(a.shape[1],a.shape[0],a.shape[1])].reshape((a.shape[1],a.shape[0],a.shape[1]-1)),axis=2).transpose()
array([[2, 2, 1],
   [2, 2, 2],
   [2, 2, 2],
   [2, 0, 2]])

Edit2:Paul 提出了一个更具可读性和更快的替代方案,即:

np.max(a[:, np.where(~np.identity(a.shape[1], dtype=bool))[1].reshape(a.shape[1], -1)], axis=-1)

在对这 3 种备选方案进行计时后,Paul 的两种解决方案在每种情况下都快 4 倍(我已针对 200 行的 2、3 和 4 列进行了基准测试)。恭喜这些令人惊叹的代码!

上次编辑(抱歉):在将 np.identity 替换为更快的 np.eye 之后,我们现在有了最快最简洁的解决方案:

np.max(a[:, np.where(~np.eye(a.shape[1], dtype=bool))[1].reshape(a.shape[1], -1)], axis=-1)

【问题讨论】:

  • 我讨厌批评,但 是您要找的吗?看,这是一个长度不到一半的可读版本:np.max(a[:, np.where(~np.identity(a.shape[1], dtype=bool))[1].reshape(a.shape[1], -1)], axis=-1)
  • 干得好@PaulPanzer!我已经编辑了问题并为你的答案加了星标,即使我更喜欢最后一个,如果你用 np.eye 替换 np.identity 也会更快。非常感谢你!
  • 谢谢。我有点惊讶这是最快的,因为根据我的经验,高级索引有点慢。也许是因为它只是少数几个索引。另外,我以前认为eyeidentity 本质上是一样的。很多东西要学...

标签: python numpy


【解决方案1】:

这里有两种解决方案,一种专门为max 设计,另一种更通用,也适用于其他操作。

利用每一行中除了可能的一个最大值之外的所有元素都是整行的最大值这一事实,我们可以使用argpartition 来廉价地找到最大的两个元素的索引。然后在最大的位置,我们把第二大的值和其他地方的最大值放在一起。也适用于超过 3 列。

>>> a
array([[6, 0, 8, 8, 0, 4, 4, 5],
       [3, 1, 5, 0, 9, 0, 3, 6],
       [1, 6, 8, 3, 4, 7, 3, 7],
       [2, 1, 6, 2, 9, 1, 8, 9],
       [7, 3, 9, 5, 3, 7, 4, 3],
       [3, 4, 3, 5, 8, 2, 2, 4],
       [4, 1, 7, 9, 2, 5, 9, 6],
       [5, 6, 8, 5, 5, 3, 3, 3]])
>>> 
>>> M, N = a.shape
>>> result = np.empty_like(a)
>>> largest_two = np.argpartition(a, N-2, axis=-1)
>>> rng = np.arange(M)
>>> result[...] = a[rng, largest_two[:, -1], None]
>>> result[rng, largest_two[:, -1]] = a[rng, largest_two[:, -2]]>>> 
>>> result
array([[8, 8, 8, 8, 8, 8, 8, 8],
       [9, 9, 9, 9, 6, 9, 9, 9],
       [8, 8, 7, 8, 8, 8, 8, 8],
       [9, 9, 9, 9, 9, 9, 9, 9],
       [9, 9, 7, 9, 9, 9, 9, 9],
       [8, 8, 8, 8, 5, 8, 8, 8],
       [9, 9, 9, 9, 9, 9, 9, 9],
       [8, 8, 6, 8, 8, 8, 8, 8]])

此解决方案取决于 max 的特定属性。

一个更通用的解决方案,例如也适用于sum 而不是max 将是。将a 的两个副本粘在一起(并排,而不是彼此重叠)。所以这些行类似于a0 a1 a2 a3 a0 a1 a2 a3。对于索引x,我们可以通过切片[x+1:x+4] 得到除ax 之外的所有内容。为此,我们使用stride_tricks:

>>> a
array([[2, 6, 0],
       [5, 0, 0],
       [5, 0, 9],
       [6, 4, 4],
       [5, 0, 8],
       [1, 7, 5],
       [9, 7, 7],
       [4, 4, 3]])
>>> M, N = a.shape
>>> aa = np.c_[a, a]
>>> ast = np.lib.stride_tricks.as_strided(aa, (M, N+1, N-1), aa.strides + aa.strides[1:])
>>> result = np.max(ast[:, 1:, :], axis=-1)
>>> result
array([[6, 2, 6],
       [0, 5, 5],
       [9, 9, 5],
       [4, 6, 6],
       [8, 8, 5],
       [7, 5, 7],
       [7, 9, 9],
       [4, 4, 4]])

# use sum instead of max
>>> result = np.sum(ast[:, 1:, :], axis=-1)
>>> result
array([[ 6,  2,  8],
       [ 0,  5,  5],
       [ 9, 14,  5],
       [ 8, 10, 10],
       [ 8, 13,  5],
       [12,  6,  8],
       [14, 16, 16],
       [ 7,  7,  8]])

【讨论】:

    【解决方案2】:

    列表理解解决方案。

    np.array([np.amax(a * (1 - np.eye(3)[j]), axis=1) for j in range(a.shape[1])]).T
    

    【讨论】:

      【解决方案3】:

      类似于@Ethan 的回答,但有np.delete()np.max()np.dstack()

      np.dstack([np.max(np.delete(a, i, 1), axis=1) for i in range(a.shape[1])])
      
      array([[2, 2, 1],
             [2, 2, 2],
             [2, 2, 2],
             [2, 0, 2]])
      
      • delete()依次“过滤”出每一列;
      • max() 查找剩余两列的行最大值
      • dstack() 堆叠生成的一维数组

      如果您有超过 3 列,请注意,这将找到“所有其他”列的最大值,而不是每行的“2 最大”列。例如:

      a2 = np.arange(25).reshape(5,5)
      np.dstack([np.max(np.delete(a2, i, 1), axis=1) for i in range(a2.shape[1])])
      
      array([[[ 4,  4,  4,  4,  3],
              [ 9,  9,  9,  9,  8],
              [14, 14, 14, 14, 13],
              [19, 19, 19, 19, 18],
              [24, 24, 24, 24, 23]]])
      

      【讨论】:

      • 不错,但我担心使用 np.delete 会太慢,而且任何列表理解都是一种 for 循环...
      • 两点都正确@fbparis。我相信保罗的解决方案在这里是无与伦比的
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2022-09-23
      • 1970-01-01
      • 2016-11-09
      • 2016-07-14
      相关资源
      最近更新 更多