【问题标题】:numpy 'isin' performance improvementnumpy 'isin' 性能改进
【发布时间】:2022-03-31 11:43:45
【问题描述】:

我有一个包含 383milj 行的矩阵,我需要根据值列表 (index_to_remove) 过滤这个矩阵。此功能在 1 次迭代中执行多次。是否有更快的替代方案:

def remove_from_result(matrix, index_to_remove, inv=True):
    return matrix[np.isin(matrix, index_to_remove, invert=inv)]

【问题讨论】:

  • @MuhammadAhmad,我不认为setnp.isin 有任何优势。
  • 我是否误解了index_to_remove 不是索引列表,而是您的函数删除的 列表?
  • @tif,是的,这是正确的。 index_to_remove 是一个值列表,它们位于矩阵的某个位置。
  • @MuhammadAhmad,是的,输出是“大致等价的”,但方法肯定不是。
  • 根据我的实验,matrix[isin(matrix,to_remove,invert=True)] 已经比 [i for i in matrix.flat if i not in to_remove][filter(lambda i:i not in to_remove,matrix.flat)] 快了 >50 倍。这暗示留给优化的空间很小。您可以尝试通过并行化此代码或使用更快的 python 实现来提高性能,例如赛通。

标签: python performance numpy


【解决方案1】:

更快的实现

这是使用集合作为@Matt Messersmith 的列表理解解决方案的编译版本。它基本上是较慢的 np.isin 方法的替代品。我在index_to_remove 是标量值的情况下遇到了一些问题,并为此实现了一个单独的版本。

代码

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def in1d_vec_nb(matrix, index_to_remove):
  #matrix and index_to_remove have to be numpy arrays
  #if index_to_remove is a list with different dtypes this 
  #function will fail

  out=np.empty(matrix.shape[0],dtype=nb.boolean)
  index_to_remove_set=set(index_to_remove)

  for i in nb.prange(matrix.shape[0]):
    if matrix[i] in index_to_remove_set:
      out[i]=False
    else:
      out[i]=True

  return out

@nb.njit(parallel=True)
def in1d_scal_nb(matrix, index_to_remove):
  #matrix and index_to_remove have to be numpy arrays
  #if index_to_remove is a list with different dtypes this 
  #function will fail

  out=np.empty(matrix.shape[0],dtype=nb.boolean)
  for i in nb.prange(matrix.shape[0]):
    if (matrix[i] == index_to_remove):
      out[i]=False
    else:
      out[i]=True

  return out


def isin_nb(matrix_in, index_to_remove):
  #both matrix_in and index_to_remove have to be a np.ndarray
  #even if index_to_remove is actually a single number
  shape=matrix_in.shape
  if index_to_remove.shape==():
    res=in1d_scal_nb(matrix_in.reshape(-1),index_to_remove.take(0))
  else:
    res=in1d_vec_nb(matrix_in.reshape(-1),index_to_remove)

  return res.reshape(shape)

示例

data = np.array([[80,1,12],[160,2,12],[240,3,12],[80,4,11]])
test_elts= np.array((80))

data[isin_nb(data[:,0],test_elts),:]

时间

test_elts = np.arange(12345)
data=np.arange(1000*1000)

#The first call has compilation overhead of about 300ms
#which is not included in the timings
#remove_from_result:     52ms
#isin_nb:                1.59ms

【讨论】:

  • 它似乎不适用于多维矩阵,例如data = [[80,1,12],[160,2,12],[240,3,12],[80,4,11]],假设我不想删除第一列中包含 80 的所有行,它会给出一个 numba 错误
  • @Ward Set 不能只使用一个值,我可以纠正这个问题。因此,如果 index_to_remove 值之一位于指定位置或每次都位于第一个位置,那么基本上您想完全删除行?处理起来会有点不同。你只有二维数组还是 nd 数组?
  • 是的,它是一个矩阵,就像我在评论中显示的数据一样,我在其中得到例如 [80,160] 并知道要搜索矩阵中的哪一列,例如在 data[:,0] 和那么它应该离开矩阵[[240,3,12]]
【解决方案2】:

过滤函数的运行时间似乎是线性 w.r.t。您输入的大小matrix。请注意,使用set 进行列表解析过滤绝对是线性的,并且您的函数的运行速度大约是在我的机器上使用相同输入的列表解析过滤器的两倍。您还可以看到,如果将大小增加 X 倍,运行时间也会增加 X 倍:

In [84]: test_elts = np.arange(12345)

In [85]: test_elts_set = set(test_elts)

In [86]: %timeit remove_from_result(np.arange(1000*1000), test_elts)
10 loops, best of 3: 81.5 ms per loop

In [87]: %timeit [x for x in np.arange(1000*1000) if x not in test_elts_set]
1 loop, best of 3: 201 ms per loop

In [88]: %timeit remove_from_result(np.arange(1000*1000*2), test_elts)
10 loops, best of 3: 191 ms per loop

In [89]: %timeit [x for x in np.arange(1000*1000*2) if x not in test_elts_set]
1 loop, best of 3: 430 ms per loop

In [90]: %timeit remove_from_result(np.arange(1000*1000*10), test_elts)
1 loop, best of 3: 916 ms per loop

In [91]: %timeit [x for x in np.arange(1000*1000*10) if x not in test_elts_set]
1 loop, best of 3: 2.04 s per loop

In [92]: %timeit remove_from_result(np.arange(1000*1000*100), test_elts)
1 loop, best of 3: 12.4 s per loop

In [93]: %timeit [x for x in np.arange(1000*1000*100) if x not in test_elts_set]
1 loop, best of 3: 26.4 s per loop

对于过滤非结构化数据,就算法复杂性而言,这是尽可能快的,因为您必须触摸每个元素一次。没有比线性时间更好的了。 有几件事可能有助于提高性能:

  1. 如果您可以访问 pyspark 之类的东西(如果您愿意支付几美元,可以在 AWS 上使用 EMR),您可以更快地完成此操作。这个问题非常尴尬地平行。您可以将输入分成 K 个块,给每个工人需要过滤的项目和一个块,让每个工人过滤,然后在最后收集/合并。或者你甚至可以尝试使用multiprocessing,但是你必须小心内存(multiprocessing 类似于 C 的fork(),它会产生子进程,但是每个子进程都会克隆你当前的内存空间)。

  2. 如果您的数据具有 某种 结构(例如已排序),您可能会更聪明,并获得亚线性算法复杂性。例如,如果您需要从一个大的、已排序的数组中删除相对较少的项目,您可以只对要删除的每个项目进行 bin 搜索。这将在 O(m log n) 时间内运行,其中 m 是要删除的项目数,n 是大数组的大小。如果 m 相对较小(与 n 相比),那么您正在做生意,那么您将接近 O(log n)。还有更聪明的方法来处理这种特殊情况,但我选择了这个,因为它很容易解释。如果您对数据的分布有所了解,那么您可能能够比线性时间做得更好。

HTH。

【讨论】:

    【解决方案3】:

    如果可能,您应该将传入 np.isin 的数组转换为整数类型。除此之外,如果可能尝试使比较数组(第二个参数尽可能小),删除重复是其中一种方法。

    【讨论】:

      猜你喜欢
      • 2021-02-13
      • 1970-01-01
      • 2019-08-17
      • 2019-12-25
      • 2015-03-21
      • 2016-12-30
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多