【问题标题】:Optimized method for calculating cosine distance in PythonPython中计算余弦距离的优化方法
【发布时间】:2010-12-21 20:26:42
【问题描述】:

我写了一个方法来计算两个数组之间的余弦距离:

def cosine_distance(a, b):
    if len(a) != len(b):
        return False
    numerator = 0
    denoma = 0
    denomb = 0
    for i in range(len(a)):
        numerator += a[i]*b[i]
        denoma += abs(a[i])**2
        denomb += abs(b[i])**2
    result = 1 - numerator / (sqrt(denoma)*sqrt(denomb))
    return result

在大型阵列上运行它可能会非常慢。此方法是否有运行速度更快的优化版本?

更新:我已经尝试了迄今为止的所有建议,包括 scipy。这是要击败的版本,结合了 Mike 和 Steve 的建议:

def cosine_distance(a, b):
    if len(a) != len(b):
        raise ValueError, "a and b must be same length" #Steve
    numerator = 0
    denoma = 0
    denomb = 0
    for i in range(len(a)):       #Mike's optimizations:
        ai = a[i]             #only calculate once
        bi = b[i]
        numerator += ai*bi    #faster than exponent (barely)
        denoma += ai*ai       #strip abs() since it's squaring
        denomb += bi*bi
    result = 1 - numerator / (sqrt(denoma)*sqrt(denomb))
    return result

【问题讨论】:

  • a 和 b 数组是复数吗?
  • 到目前为止我已经尝试了所有的建议,目前 Mike Dunlavey 的精简现有代码的建议已经给出了最好的结果。我想我会留下这个问题,以防有其他解决问题的策略,但大多数建议最终实际上比原始代码运行得慢,所以 Python 必须在动态优化方面做得很好。还有@gnibbler,我没有使用任何复数。
  • 我不明白你为什么要先练腹肌再练。
  • 我刚刚进行了一个快速测试,当列表大约有 1000 个元素时,使用 numpy 会更快。
  • 小数组的 numpy 速度较慢的原因是转换为 numpy 数组的开销。

标签: python arrays optimization distance


【解决方案1】:

如果你可以使用 SciPy,你可以使用 cosine from spatial.distance:

http://docs.scipy.org/doc/scipy/reference/spatial.distance.html

如果你不能使用 SciPy,你可以尝试通过重写你的 Python 来获得一个小的加速(编辑:但它并没有像我想象的那样工作,见下文)。

from itertools import izip
from math import sqrt

def cosine_distance(a, b):
    if len(a) != len(b):
        raise ValueError, "a and b must be same length"
    numerator = sum(tup[0] * tup[1] for tup in izip(a,b))
    denoma = sum(avalue ** 2 for avalue in a)
    denomb = sum(bvalue ** 2 for bvalue in b)
    result = 1 - numerator / (sqrt(denoma)*sqrt(denomb))
    return result

最好在 a 和 b 的长度不匹配时引发异常。

通过在对sum() 的调用中使用生成器表达式,您可以使用 Python 内部的 C 代码完成的大部分工作来计算您的值。这应该比使用for 循环更快。

我没有计时,所以我无法猜测它可能会快多少。但几乎可以肯定,SciPy 代码是用 C 或 C++ 编写的,并且应该尽可能快。

如果您正在使用 Python 进行生物信息学,那么您确实应该使用 SciPy。

编辑:Darius Bacon 为我的代码计时,发现它变慢了。所以我给我的代码计时了......是的,它更慢。所有人的教训:当你试图加快速度时,不要猜测,要衡量。

我很困惑为什么我在 Python 的 C 内部进行更多工作的尝试速度较慢。我尝试了长度为 1000 的列表,但速度仍然较慢。

我不能再花时间尝试巧妙地破解 Python。如果你需要更快的速度,我建议你试试 SciPy。

编辑:我只是手动测试,没有时间。我发现简而言之 a 和 b,旧代码更快;对于long a和b,新代码更快;在这两种情况下,差异都不大。 (我现在想知道我是否可以在我的 Windows 计算机上信任 timeit;我想在 Linux 上再次尝试这个测试。)我不会更改工作代码来尝试让它更快。还有一次,我敦促您尝试 SciPy。 :-)

【讨论】:

  • 分子行不正确:它执行嵌套循环而不是并行循环。
  • 另外,当我修复该行以获得正确答案时,它仍然比原始代码慢。无论如何都同意 SciPy! (numerator = sum(avalue * bvalue for avalue, bvalue in zip(a, b)))
  • 与 SciPy 的良好通话。不幸的是,您的非 SciPy 重写返回了错误的值。用 gnibbler 的结果替换分子行得到正确答案,但它实际上比我的原始代码慢得多。
  • 有趣的是,scipy 实际上要慢得多。为了测试,我通过 100K 迭代运行了几个小数组。原始代码运行约 1.3 秒,scipy 运行约 7.5 秒。我想知道这些表是否会打开更大的数组?
  • 出于好奇,因为我没有安装 SciPy,但我一直对这个项目很感兴趣,你有什么时间使用来自 spatial.distance 的余弦来处理这个特殊案例吗?
【解决方案2】:

(我最初认为)如果不使用 C(如 numpy 或 scipy)或更改您的计算内容,您将不会加快速度。但无论如何,这就是我要尝试的方法:

from itertools import imap
from math import sqrt
from operator import mul

def cosine_distance(a, b):
    assert len(a) == len(b)
    return 1 - (sum(imap(mul, a, b))
                / sqrt(sum(imap(mul, a, a))
                       * sum(imap(mul, b, b))))

在具有 500k 元素数组的 Python 2.6 中,它的速度大约是后者的两倍。 (将地图更改为 imap 后,跟随 Jarret Hardie。)

这是原始海报修改代码的调整版本:

from itertools import izip

def cosine_distance(a, b):
    assert len(a) == len(b)
    ab_sum, a_sum, b_sum = 0, 0, 0
    for ai, bi in izip(a, b):
        ab_sum += ai * bi
        a_sum += ai * ai
        b_sum += bi * bi
    return 1 - ab_sum / sqrt(a_sum * b_sum)

它很丑,但它确实出来得更快。 . .

编辑:试试Psyco!它将最终版本的速度提高了 4 倍。我怎么会忘记?

【讨论】:

  • 不错的补充 - 很高兴听到使用 imap 提供了 mul 优于 ** 2 的优势
  • 我觉得没那么丑:p
  • 看到命令式代码击败了更直接地表达问题的纯函数式代码,我有点懊恼。
【解决方案3】:

如果您要对其进行平方,则无需使用abs() 中的a[i]b[i]

a[i]b[i] 存储在临时变量中,以避免多次进行索引。 也许编译器可以对此进行优化,但也许不行。

检查**2 运算符。是将其简化为乘法,还是使用通用幂函数(对数 - 乘以 2 - 反对数)。

不要执行 sqrt 两次(尽管这样做的成本很小)。做sqrt(denoma * denomb)

【讨论】:

  • 好电话...每个都节省了一点时间。
  • @Dan:欢迎。接下来我会看看一些展开是否会有所帮助,以防迭代器花费你(他们倾向于这样做)。接下来我会做一些堆栈采样,看看函数是否被调用超过必要(或者是否有任何其他未被注意到的时间肿瘤)。
【解决方案4】:

与 Darius Bacon 的答案类似,我一直在玩弄 operator 和 itertools 以产生更快的答案。根据 timeit,以下似乎在 500 项数组上快 1/3:

from math import sqrt
from itertools import imap
from operator import mul

def op_cosine(a, b):
    dot_prod = sum(imap(mul, a, b))
    a_veclen = sqrt(sum(i ** 2 for i in a))
    b_veclen = sqrt(sum(i ** 2 for i in b))

    return 1 - dot_prod / (a_veclen * b_veclen)

【讨论】:

    【解决方案5】:

    这对于大约 1000 多个元素的数组来说更快。

    from numpy import array
    def cosine_distance(a, b):
        a=array(a)
        b=array(b)
        numerator=(a*b).sum()
        denoma=(a*a).sum()
        denomb=(b*b).sum()
        result = 1 - numerator / sqrt(denoma*denomb)
        return result
    

    【讨论】:

      【解决方案6】:

      在 SciPy 中使用 C 代码对长输入数组大有裨益。对短输入数组使用简单直接的 Python 获胜; Darius Bacon 的基于izip() 的代码进行了最佳基准测试。因此,最终的解决方案是在运行时根据输入数组的长度来决定使用哪一个:

      from scipy.spatial.distance import cosine as scipy_cos_dist
      
      from itertools import izip
      from math import sqrt
      
      def cosine_distance(a, b):
          len_a = len(a)
          assert len_a == len(b)
          if len_a > 200:  # 200 is a magic value found by benchmark
              return scipy_cos_dist(a, b)
          # function below is basically just Darius Bacon's code
          ab_sum = a_sum = b_sum = 0
          for ai, bi in izip(a, b):
              ab_sum += ai * bi
              a_sum += ai * ai
              b_sum += bi * bi
          return 1 - ab_sum / sqrt(a_sum * b_sum)
      

      我制作了一个测试工具来测试具有不同长度输入的函数,发现在长度 200 左右 SciPy 函数开始获胜。输入数组越大,它就越大。对于长度很短的数组,比如长度为 3,更简单的代码会胜出。此函数会增加少量开销来决定采用哪种方式,然后采用最佳方式。

      如果您有兴趣,这里是测试工具:

      from darius2 import cosine_distance as fn_darius2
      fn_darius2.__name__ = "fn_darius2"
      
      from ult import cosine_distance as fn_ult
      fn_ult.__name__ = "fn_ult"
      
      from scipy.spatial.distance import cosine as fn_scipy
      fn_scipy.__name__ = "fn_scipy"
      
      import random
      import time
      
      lst_fn = [fn_darius2, fn_scipy, fn_ult]
      
      def run_test(fn, lst0, lst1, test_len):
          start = time.time()
          for _ in xrange(test_len):
              fn(lst0, lst1)
          end = time.time()
          return end - start
      
      for data_len in range(50, 500, 10):
          a = [random.random() for _ in xrange(data_len)]
          b = [random.random() for _ in xrange(data_len)]
          print "len(a) ==", len(a)
          test_len = 10**3
          for fn in lst_fn:
              n = fn.__name__
              r = fn(a, b)
              t = run_test(fn, a, b, test_len)
              print "%s:\t%f seconds, result %f" % (n, t, r)
      

      【讨论】:

        【解决方案7】:
        def cd(a,b):
            if(len(a)!=len(b)):
                raise ValueError, "a and b must be the same length"
            rn = range(len(a))
            adb = sum([a[k]*b[k] for k in rn])
            nma = sqrt(sum([a[k]*a[k] for k in rn]))
            nmb = sqrt(sum([b[k]*b[k] for k in rn]))
        
            result = 1 - adb / (nma*nmb)
            return result
        

        【讨论】:

        • 您在对sum() 的调用中使用列表推导。这将创建一个列表,然后sum() 将使用该列表一次,然后该列表将被垃圾收集。 Python 有一个名为“生成器表达式”的漂亮特性,您可以在其中使用与列表推导式相同的语法,但它会创建一个迭代器。如果您只是从对sum() 的调用中删除[],您现在将使用生成器表达式。在此处阅读更多信息:docs.python.org/howto/…
        • @steveha:取决于输入长度和功能。我不知道这里,但是 str.join(..) 对于短输入(len ~100)来说,列表理解比 genexps 更快。
        • @kaizer.se: str.join 是一个特殊情况,因为当它有一个列表时,它首先对镜头求和,然后创建一个总大小的字符串并用零件填充它;否则,它会以明显的方式构建字符串(对于可迭代的部分:结果+=部分)
        【解决方案8】:

        您更新后的解决方案仍然有两个平方根。您可以将 sqrt 行替换为:

        结果 = 1 - 分子 / (sqrt(denoma*denomb))

        乘法通常比 sqrt 快很多。它可能看起来并不多,因为它只在函数中调用了一次,但听起来你正在计算很多余弦距离,所以改进会加起来。

        您的代码看起来应该适合向量优化。因此,如果跨平台支持不是问题并且您想进一步加快速度,您可以在 C 中编写余弦距离代码,并确保您的编译器积极地矢量化生成的代码(即使 Pentium II 也能够进行一些浮点矢量化)

        【讨论】:

          猜你喜欢
          • 2014-08-21
          • 1970-01-01
          • 2020-01-14
          • 1970-01-01
          • 2020-06-06
          • 2016-11-29
          • 1970-01-01
          • 2017-09-15
          • 2017-11-05
          相关资源
          最近更新 更多