【问题标题】:How to vectorize a simple for loop in Python/Numpy如何在 Python/Numpy 中矢量化一个简单的 for 循环
【发布时间】:2012-11-03 00:11:45
【问题描述】:

我发现了许多如何在 Python/NumPy 中矢量化 for 循环的示例。不幸的是,我不知道如何使用矢量化形式减少简单 for 循环的计算时间。在这种情况下甚至可能吗?

time = np.zeros(185000)
lat1 = np.array(([48.78,47.45],[38.56,39.53],...)) # ~ 200000 rows
lat2 = np.array(([7.78,5.45],[7.56,5.53],...)) # same number of rows as time
for ii in np.arange(len(time)):
    pos = np.argwhere( (lat1[:,0]==lat2[ii,0]) and \
                       (lat1[:,1]==lat2[ii,1]) )
    if pos.size:
        pos = int(pos)
        time[ii] = dtime[pos]

【问题讨论】:

  • latlontime 是什么?特别是它们的形状是什么?
  • 我更新了上面的示例值。
  • 您能解释一下pos = np.argwhere( (lat1[:,0]==lat2[ii,0]) and (lat1[:,1]==lat2[ii,1]) ) 的含义吗?那么,你想在 lat2 中找到这样一个等于 lat1 的行吗?你不害怕浮点舍入错误吗?如果是这样,您可以在 lat2 上使用二进制搜索(在其排序副本中搜索)
  • 我正在寻找两列中 lat1 和 lat2 相等的行。在这种情况下,我需要 lat1 和 lat2 的行号。目前“ii”和“pos”给了我这个,它有效。我在两个数组上都使用了 np.around(XX,decimals=2) 以避免舍入错误。
  • 所以,如果lat1 = [[1,2], [3,4], [5,6], [7,8]]lat2 = [[3,4], [5,6], [7,8], [1,2]] 那么算法的结果应该是[1, 2, 3, 0](lat2 的 0-st 元素在 lat1 的 1-st 位置,lat2 的 1 个元素是on 2, 2 on 3, 3 on 0) 这是你想要的吗?

标签: python numpy scipy vectorization


【解决方案1】:

这里有一个解决方案。我不确定是否可以对其进行矢量化。如果你想让它抵抗“浮动比较错误”,你应该修改is_lessis_greater。 整个算法只是一个二分搜索。

import numpy as np

#lexicographicaly compare two points - a and b

def is_less(a, b):
    i = 0
    while i<len(a):
        if a[i]<b[i]:
            return True
        else:
            if a[i]>b[i]:
                return False
        i+=1
    return False

def is_greater(a, b):
    i = 0
    while i<len(a):
        if a[i]>b[i]:
            return True
        else:
            if a[i]<b[i]:
                return False
        i+=1
    return False


def binary_search(a, x, lo=0, hi=None):
    if hi is None:
        hi = len(a)
    while lo < hi:
        mid = (lo+hi)//2
        midval = a[mid]
        if is_less(midval, x):
            lo = mid+1
        elif is_greater(midval, x):
            hi = mid
        else:
            return mid
    return -1

def lex_sort(v): #sort by 1 and 2 column respectively
    #return v[np.lexsort((v[:,2],v[:,1]))]
    order = range(1, v.shape[1])
    return v[np.lexsort(tuple(v[:,i] for i in order[::-1]))]

def sort_and_index(arr):
    ind = np.indices((len(arr),)).reshape((len(arr), 1))
    arr = np.hstack([ind, arr]) # add an index column as first column
    arr = lex_sort(arr)
    arr_cut = arr[:,1:] # an array to do binary search in
    arr_ind = arr[:,:1] # shuffled indices
    return arr_ind, arr_cut


#lat1 = np.array(([1,2,3], [3,4,5], [5,6,7], [7,8,9])) # ~ 200000 rows
lat1 = np.arange(1,800001,1).reshape((200000,4))
#lat2 = np.array(([3,4,5], [5,6,7], [7,8,9], [1,2,3])) # same number of rows as time
lat2 = np.arange(101,800101,1).reshape((200000,4))

lat1_ind, lat1_cut = sort_and_index(lat1)

time_arr = np.zeros(200000)
import time
start = time.time()

for ii, elem in enumerate(lat2):
    pos = binary_search(lat1_cut, elem)
    if pos == -1:
        #Not found
        continue
    pos = lat1_ind[pos][0]
    #print "element in lat2 with index",ii,"has position",pos,"in lat1"
print time.time()-start

注释的打印是您拥有 lat1 和 lat2 对应索引的地方。在 200000 行上工作 7 秒。

【讨论】:

  • 确实,这非常快!但是,我得到了奇怪的结果..我现在正在检查为什么。如果我 lat1 和 lat2 有 4 列要比较,我需要更改什么?是否可以直接使用它们?
  • 此外,我根本不明白二进制搜索中发生了什么以及为什么这比我开始帖子中的简单循环要快得多。你有一些关于它的文献吗?
  • 好吧,假设你有一个排序数组[1, 2, 6, 17, 25, 29, 37],你想找到25元素的索引。一种方法是遍历整个数组并将每个元素与寻找的元素进行比较。但是我们的数组是排序的,因此通过查看任何索引处的元素,我们可以判断所查找的元素是在该元素的右侧还是左侧。例如让我们从数组的中心开始:当前元素是17,索引是3(从0开始计数)。我们正在寻找2525&gt;17,因此在左侧(索引[0,3])中寻找我们的元素是没有意义的。
  • 因此我们可以代替遍历整个数组执行以下智能过程:1)选择输入数组的中心 2)将当前元素与寻找的元素进行比较。 3)如果等于,那么我们发现否则有两种情况:seek 较小,seek 较大。如果较小,则设置输入数组 = 输入数组的左半部分并转到 1)。如果更大,则相同但右半部分。该算法允许我们在 O(log2(N)) 而不是 O(N) 时间内找到索引。比较:N = 200000 log2(N) = 17.609640474436812。这大致意味着通过二分查找,您将在最多 17 次操作中找到元素。
  • 用朴素的搜索算法代替 200000 次操作。这里有一个很好的解释community.topcoder.com/…
【解决方案2】:

找到所有匹配项的最快方法可能是对两个数组进行排序并一起遍历它们,就像这个工作示例:

import numpy as np

def is_less(a, b):
    # this ugliness is needed because we want to compare lexicographically same as np.lexsort(), from the last column backward
    for i in range(len(a)-1, -1, -1):
        if a[i]<b[i]: return True
        elif a[i]>b[i]: return False
    return False

def is_equal(a, b):
    for i in range(len(a)):
        if a[i] != b[i]: return False
    return True

# lat1 = np.array(([48.78,47.45],[38.56,39.53]))
# lat2 = np.array(([7.78,5.45],[48.78,47.45],[7.56,5.53]))
lat1 = np.load('arr.npy')
lat2 = np.load('refarr.npy')

idx1 = np.lexsort( lat1.transpose() )
idx2 = np.lexsort( lat2.transpose() )
ii = 0
jj = 0
while ii < len(idx1) and jj < len(idx2):
    a = lat1[ idx1[ii] , : ]
    b = lat2[ idx2[jj] , : ]
    if is_equal( a, b ):
        # do stuff with match
        print "match found: lat1=%s lat2=%s %d and %d" % ( repr(a), repr(b), idx1[ii], idx2[jj] )
        ii += 1
        jj += 1
    elif is_less( a, b ):
        ii += 1
    else:
        jj += 1

这可能不是完美的 Python 语言(也许有人可以想到使用生成器或 itertools 的更好实现?)但很难想象任何依赖一次搜索一个点的方法会在速度上超越这一点。

【讨论】:

  • 不幸的是,这不适用于我的数据:dropbox.com/s/xs35kvoexi85bk0/arr.npydropbox.com/s/g4bdk509bvwzn3u/refarr.npy。 arr 在您的代码中应该是 lat1 和 refarr lat2。
  • @HyperCube,请尝试上面的最新代码。在您的示例数据(包括 np.load()s)上工作 0.4 秒。我忘记了(明显的)索引取消引用 lat1[ idx1[ii], ... ] 等等;它恰好适用于我使用的数据。
  • 效果很好!完整数据集在我的机器上需要 5.2 秒,而 alex_jordan 的解决方案需要 10.3 秒
  • 太棒了! O(N) 而不是我的 O(N*log(N))。我怎么没想到。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2011-02-09
  • 1970-01-01
  • 2021-09-10
  • 2018-10-15
  • 2016-05-26
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多