【Question Title】: zip iterators asserting for equal length in python
【Posted】: 2016-01-02 10:58:06
【Question】:

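I'm looking for a nice way to zip several iterables, raising an exception if the lengths of the iterables are not equal.

In the case where the iterables are lists or have a len method, this solution is clean and easy: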

def zip_equal(it1, it2):
    if len(it1) != len(it2):
        raise ValueError("Lengths of iterables are different")
    return zip(it1, it2)

However, if it1 or it2 is a generator, the previous function fails because the length is not defined: TypeError: object of type 'generator' has no len()
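A minimal sketch of that failure, using a throwaway generator (the names gen, g, message and count are illustrative only). It also shows why a length check up front is not an option: the only way to count a generator's items is to consume them.

```python
def gen():
    yield from 'abc'

g = gen()
try:
    len(g)                      # generators do not implement __len__
except TypeError as exc:
    message = str(exc)

# Counting is possible only by consuming the generator.
count = sum(1 for _ in g)
remaining = list(g)             # nothing is left to zip afterwards
```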

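I imagine the itertools module offers a simple way to implement this, but so far I have not been able to find it. I have come up with this home-made solution: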

def zip_equal(it1, it2):
    exhausted = False
    while True:
        try:
            el1 = next(it1)
            if exhausted: # in a previous iteration it2 was exhausted but it1 still has elements
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            exhausted = True
            # it2 must be exhausted too.
        try:
            el2 = next(it2)
            # here it2 is not exhausted.
            if exhausted:  # it1 was exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            # here it2 is exhausted
            if not exhausted:
                # but it1 was not exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
            exhausted = True
        if not exhausted:
            yield (el1, el2)
        else:
            return

The solution can be tested with the following code:

it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it1, it2))           # len(it1) < len(it2) => raise
it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it2, it1))           # len(it2) > len(it1) => raise
it1 = (x for x in ['a', 'b', 'c', 'd'])  # it1 has length 4
it2 = (x for x in [0, 1, 2, 3])          # it2 has length 4
list(zip_equal(it1, it2))                # like zip (or izip in python2)

Am I overlooking any alternative solution? Is there a simpler implementation of my zip_equal function?

Update:

  • Requires python 3.10 or newer, see Asocia's answer
  • pythonanswer
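  • Simple answer without external dependencies: Martijn Pieters' answer (please check the comments for a bug fix in some corner cases)
  • More complex than Martijn's, but with better performance: cjerdonek's answer
  • If you don't mind a package dependency, see pylang's answer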

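【Comments】:

  • This question (and, I guess, its answers) was cited in passing by PEP 618 -- Add Optional Length-Checking To zip, which brought this to Python 3.10, as evidence that "hand-writing a robust solution that gets this right isn't trivial" :-)

Tags: python itertools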


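【Solution 1】:

PEP 618 introduced an optional boolean keyword parameter strict for the built-in zip function.

Quoting What's New In Python 3.10:

The zip() function now has an optional strict flag, used to require that all the iterables have an equal length.

When enabled, a ValueError is raised if one of the arguments is exhausted before the others: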

>>> list(zip('ab', range(3)))
[('a', 0), ('b', 1)]
>>> list(zip('ab', range(3), strict=True))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: zip() argument 2 is longer than argument 1

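【Comments】:

    【Solution 2】:

    I can think of a simpler solution: use itertools.zip_longest() and raise an exception if the sentinel value used to pad out the shorter iterables is present in the produced tuple: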

    from itertools import zip_longest
    
    def zip_equal(*iterables):
        sentinel = object()
        for combo in zip_longest(*iterables, fillvalue=sentinel):
            if sentinel in combo:
                raise ValueError('Iterables have different lengths')
            yield combo
    

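    Unfortunately, we can't use zip() with yield from to avoid a Python-code loop with a test on each iteration; once the shortest iterator runs out, zip() has already advanced all the preceding iterators, so the evidence is swallowed if any of them held just one extra item.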
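A small sketch of how plain zip() swallows that evidence (variable names are illustrative). Whether the extra item survives depends only on argument order, because zip() pulls from its arguments left to right and stops at the first one that is exhausted.

```python
# Longer iterator first: zip() pulls 3 from it, then finds the second one
# exhausted and stops, discarding the 3.
longer = iter([1, 2, 3])
shorter = iter(['a', 'b'])
pairs = list(zip(longer, shorter))
left_in_longer = list(longer)

# Shorter iterator first: zip() stops before touching the longer one again,
# so the extra item is still there to be detected.
longer2 = iter([1, 2, 3])
shorter2 = iter(['a', 'b'])
pairs2 = list(zip(shorter2, longer2))
left_in_longer2 = list(longer2)
```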

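    【Comments】:

    • The yield from solution is very nice. Thanks for providing two different solutions.
    • Oh, one more thing: the second solution does not work in one corner case. Suppose there are two iterators and the second one is shorter by one item. Because zip has already called __next__ on the first iterator, both end up exhausted even though the first one was longer.
    • Also, it may be better to replace if sentinel in combo with if any(sentinel is c for c in combo), because a in b is the same as any(b_ == a for b_ in b), and equality is sometimes overridden (for example when the elements of combo are numpy arrays).
    • @Peter Actually sentinel in combo checks both identity and equality (at least for now :-), but yes, it is wrong if some element claims to be equal to the sentinel, and numpy arrays can even make it crash, as np.ones(2) in [object()] does. The more-itertools function actually works like Martijn's but checks only identity, in a loop. That is probably the main reason it is slower, see the benchmark in my answer.
    • @zeehio: That's fair, his solution is great!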
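The difference between the membership test and the identity test discussed in these comments can be sketched without numpy. AlwaysEqual is a made-up class standing in for any element whose __eq__ answers True to everything.

```python
class AlwaysEqual:
    """Illustrative element type that claims equality with any object."""
    def __eq__(self, other):
        return True


sentinel = object()
combo = (AlwaysEqual(), 1)      # the sentinel itself is NOT in this tuple

# `in` tries identity first and then falls back to ==, so it is fooled.
found_by_membership = sentinel in combo

# An explicit identity check is not.
found_by_identity = any(item is sentinel for item in combo)
```

With the membership test, a zip_longest-based zip_equal would report a length mismatch for perfectly equal-length inputs that happen to contain such an element.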
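    【Solution 3】:

    Here is an approach that doesn't require any extra check on each loop of the iteration. This could be desirable especially for long iterables.

    The idea is to pad each iterable with a "value" at the end that raises an exception when reached, and then do the needed verification only at the very end. The approach uses zip() and itertools.chain().

    The code below was written for Python 3.5.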

    import itertools
    
    class ExhaustedError(Exception):
        def __init__(self, index):
            """The index is the 0-based index of the exhausted iterable."""
            self.index = index
    
    def raising_iter(i):
        """Return an iterator that raises an ExhaustedError."""
        raise ExhaustedError(i)
        yield
    
    def terminate_iter(i, iterable):
        """Return an iterator that raises an ExhaustedError at the end."""
        return itertools.chain(iterable, raising_iter(i))
    
    def zip_equal(*iterables):
        iterators = [terminate_iter(*args) for args in enumerate(iterables)]
        try:
            yield from zip(*iterators)
        except ExhaustedError as exc:
            index = exc.index
            if index > 0:
                raise RuntimeError('iterable {} exhausted first'.format(index)) from None
            # Check that all other iterators are also exhausted.
            for i, iterator in enumerate(iterators[1:], start=1):
                try:
                    next(iterator)
                except ExhaustedError:
                    pass
                else:
                    raise RuntimeError('iterable {} is longer'.format(i)) from None
    

    Below is what it looks like when used.

    >>> list(zip_equal([1, 2], [3, 4], [5, 6]))
    [(1, 3, 5), (2, 4, 6)]
    
    >>> list(zip_equal([1, 2], [3], [4]))
    RuntimeError: iterable 1 exhausted first
    
    >>> list(zip_equal([1], [2, 3], [4]))
    RuntimeError: iterable 1 is longer
    
    >>> list(zip_equal([1], [2], [3, 4]))
    RuntimeError: iterable 2 is longer
    

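    【Comments】:

    • I prefer this approach. It is a bit more complicated than the accepted answer, but it uses EAFP instead of LBYL and also gives better error messages. Bravo.
    • I edited my question with a short discussion that points to your answer for when performance is an issue. Thanks for your solution!
    • The only problem with this approach is that it raises an unnecessary exception even when all iterators have the same length.
    • You can speed this up by zipping the first iterable directly, without chaining a raising iter to it. That is, in terminate_iter do if i == 0: return iterable. Then raise unconditionally inside except ExhaustedError, and unindent the "Check that all other" part by one level. That said, I have now taken a bigger step with my solution.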
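The variant suggested in the last comment might look roughly like the sketch below. It is one reading of that comment, not code from the commenter; the function name zip_equal_first_unchained is made up here, and the list comprehension replaces terminate_iter for brevity.

```python
from itertools import chain


class ExhaustedError(Exception):
    def __init__(self, index):
        self.index = index  # 0-based index of the exhausted iterable


def raising_iter(i):
    """Generator that raises ExhaustedError as soon as it is advanced."""
    raise ExhaustedError(i)
    yield  # unreachable; its presence makes this function a generator


def zip_equal_first_unchained(*iterables):
    # The first iterable is zipped directly; only the others get a raising tail.
    iterators = [
        iterable if i == 0 else chain(iterable, raising_iter(i))
        for i, iterable in enumerate(iterables)
    ]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        # A later iterable ran out while the first one still had an item.
        raise RuntimeError('iterable {} exhausted first'.format(exc.index)) from None
    # zip() stopped normally, so the first iterable is exhausted.
    # Every other iterator must now be sitting exactly at its raising tail.
    for i, iterator in enumerate(iterators[1:], start=1):
        try:
            next(iterator)
        except ExhaustedError:
            pass
        else:
            raise RuntimeError('iterable {} is longer'.format(i))
```

In the equal-length case no exception passes through zip() itself; the ExhaustedError instances are raised only during the final check, one per non-first iterable.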
    【Solution 4】:

    Use more_itertools.zip_equal (v8.3.0+):

    Code

    import more_itertools as mit
    

    Demo

    list(mit.zip_equal(range(3), "abc"))
    # [(0, 'a'), (1, 'b'), (2, 'c')]
    
    list(mit.zip_equal(range(3), "abcd"))
    # UnequalIterablesError
    

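    more_itertools is a third-party package, installed via pip install more_itertools.

    【Comments】:

      【Solution 5】:

      A new solution that is even faster than cjerdonek's, on which it is based, plus a benchmark. Benchmark first; my solution is the green one. Note that the "total size" is the same in all cases: two million values. The x-axis is the number of iterables, from 1 iterable with two million values, then 2 iterables with one million values each, all the way up to 100,000 iterables with 20 values each.

      The black one is Python's zip. I used Python 3.8 here, so it doesn't do this question's task of checking for equal lengths, but I include it as a reference for the maximum speed one can hope for. You can see that my solution is pretty close.

      For the probably most common case of zipping two iterables, mine is almost three times as fast as cjerdonek's previously fastest solution, and not much slower than zip. Times as text (in milliseconds):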

               number of iterables     1     2     3     4     5    10   100  1000 10000 50000 100000
      -----------------------------------------------------------------------------------------------
             more_itertools__pylang 209.3 132.1 105.8  93.7  87.4  74.4  54.3  51.9  53.9  66.9  84.5
         fillvalue__Martijn_Pieters 159.1 101.5  85.6  74.0  68.8  59.0  44.1  43.0  44.9  56.9  72.0
           chain_raising__cjerdonek  58.5  35.1  26.3  21.9  19.7  16.6  10.4  12.7  34.4 115.2 223.2
           ziptail__Stefan_Pochmann  10.3  12.4  10.4   9.2   8.7   7.8   6.7   6.8   9.4  22.6  37.8
                                zip  10.3   8.5   7.8   7.4   7.4   7.1   6.4   6.8   9.0  19.4  32.3
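The "almost three times" claim can be checked against the two-iterable column of the table above. The sketch below only copies those three numbers; the variable names are illustrative.

```python
# Milliseconds for 2 iterables, copied from the table above.
times_ms = {
    'chain_raising__cjerdonek': 35.1,
    'ziptail__Stefan_Pochmann': 12.4,
    'zip': 8.5,
}

speedup_over_cjerdonek = (
    times_ms['chain_raising__cjerdonek'] / times_ms['ziptail__Stefan_Pochmann']
)
slowdown_vs_zip = times_ms['ziptail__Stefan_Pochmann'] / times_ms['zip']

print(f'{speedup_over_cjerdonek:.2f}x faster than chain_raising')  # 2.83x
print(f'{slowdown_vs_zip:.2f}x the time of plain zip')             # 1.46x
```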
      

      My code (Try it online!):

      from itertools import chain
      
      def zip_equal(*iterables):
      
          # For trivial cases, use pure zip.
          if len(iterables) < 2:
              return zip(*iterables)
      
          # Tail for the first iterable
          first_stopped = False
          def first_tail():
              nonlocal first_stopped 
              first_stopped = True
              return
              yield
      
          # Tail for the zip
          def zip_tail():
              if not first_stopped:
                  raise ValueError('zip_equal: first iterable is longer')
              for _ in chain.from_iterable(rest):
                  raise ValueError('zip_equal: first iterable is shorter')
                  yield
      
          # Put the pieces together
          iterables = iter(iterables)
          first = chain(next(iterables), first_tail())
          rest = list(map(iter, iterables))
          return chain(zip(first, *rest), zip_tail())
      

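      The basic idea is to let zip(*iterables) do all the work, and then, after it has stopped because some iterable was exhausted, check whether all iterables had the same length. They did if and only if:

      1. zip stopped because the first iterable had no further element (i.e., no other iterable is shorter).
      2. None of the other iterables has any further element (i.e., no other iterable is longer).

      How I check these criteria:

      • Since I need to check these criteria after zip has ended, I can't return the zip object on its own. Instead, I chain an empty zip_tail iterator behind it that does the checking.
      • To support checking the first criterion, I chain an empty first_tail iterator behind the first iterable. Its sole job is to record that iteration of the first iterable has stopped (i.e., it was asked for another element and had none, so the first_tail iterator was asked for one instead).
      • To support checking the second criterion, I fetch iterators for all the other iterables and keep them in a list before giving them to zip.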
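The "empty tail" trick from the bullets above can be isolated in a few lines. This is a sketch for illustration only: with_exhaustion_flag is a hypothetical helper, and it keeps the flag in a dict rather than the nonlocal variable used in the answer's code.

```python
from itertools import chain


def with_exhaustion_flag(iterable):
    """Return (iterator, was_exhausted); hypothetical helper for illustration."""
    state = {'exhausted': False}

    def tail():
        # Runs only when chain() has drained `iterable` and is asked for more.
        state['exhausted'] = True
        return
        yield  # unreachable; makes tail() an (empty) generator

    return chain(iterable, tail()), lambda: state['exhausted']


# zip() stops because the first argument ran out: the tail was reached.
first, first_stopped = with_exhaustion_flag('ab')
pairs = list(zip(first, 'xyz'))
stopped_on_first = first_stopped()

# zip() stops because the second argument ran out: the tail was never reached.
first2, first2_stopped = with_exhaustion_flag('abc')
pairs2 = list(zip(first2, 'xy'))
stopped_on_first2 = first2_stopped()
```

The tail yields nothing, so it never changes what zip() produces; it only runs a side effect at the moment the wrapped iterable is found to be exhausted.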

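      Side note: more-itertools uses pretty much the same approach as Martijn, but does proper is checks instead of Martijn's not quite correct sentinel in combo. That is probably the main reason it is slower.

      Benchmark code (Try it online!):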

      import timeit
      import warnings
      import itertools
      from itertools import repeat, chain, zip_longest
      from collections import deque
      from sys import hexversion, maxsize
      
      #-----------------------------------------------------------------------------
      # Solution by Martijn Pieters
      #-----------------------------------------------------------------------------
      
      def zip_equal__fillvalue__Martijn_Pieters(*iterables):
          sentinel = object()
          for combo in zip_longest(*iterables, fillvalue=sentinel):
              if sentinel in combo:
                  raise ValueError('Iterables have different lengths')
              yield combo
      
      #-----------------------------------------------------------------------------
      # Solution by pylang
      #-----------------------------------------------------------------------------
      
      def zip_equal__more_itertools__pylang(*iterables):
          return more_itertools__zip_equal(*iterables)
      
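      _marker = object()
      
      class UnequalIterablesError(ValueError):
          """Minimal stand-in (an assumption) for more_itertools.UnequalIterablesError."""
          def __init__(self, details=None):
              super().__init__('Iterables have different lengths')
              self.details = details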
      
      def _zip_equal_generator(iterables):
          for combo in zip_longest(*iterables, fillvalue=_marker):
              for val in combo:
                  if val is _marker:
                      raise UnequalIterablesError()
              yield combo
      
      def more_itertools__zip_equal(*iterables):
          """``zip`` the input *iterables* together, but raise
          ``UnequalIterablesError`` if they aren't all the same length.
      
              >>> it_1 = range(3)
              >>> it_2 = iter('abc')
              >>> list(zip_equal(it_1, it_2))
              [(0, 'a'), (1, 'b'), (2, 'c')]
      
              >>> it_1 = range(3)
              >>> it_2 = iter('abcd')
              >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
              Traceback (most recent call last):
              ...
              more_itertools.more.UnequalIterablesError: Iterables have different
              lengths
      
          """
          if hexversion >= 0x30A00A6:
              warnings.warn(
                  (
                      'zip_equal will be removed in a future version of '
                      'more-itertools. Use the builtin zip function with '
                      'strict=True instead.'
                  ),
                  DeprecationWarning,
              )
          # Check whether the iterables are all the same size.
          try:
              first_size = len(iterables[0])
              for i, it in enumerate(iterables[1:], 1):
                  size = len(it)
                  if size != first_size:
                      break
              else:
                  # If we didn't break out, we can use the built-in zip.
                  return zip(*iterables)
      
              # If we did break out, there was a mismatch.
              raise UnequalIterablesError(details=(first_size, i, size))
          # If any one of the iterables didn't have a length, start reading
          # them until one runs out.
          except TypeError:
              return _zip_equal_generator(iterables)
      
      #-----------------------------------------------------------------------------
      # Solution by cjerdonek
      #-----------------------------------------------------------------------------
      
      class ExhaustedError(Exception):
          def __init__(self, index):
              """The index is the 0-based index of the exhausted iterable."""
              self.index = index
      
      def raising_iter(i):
          """Return an iterator that raises an ExhaustedError."""
          raise ExhaustedError(i)
          yield
      
      def terminate_iter(i, iterable):
          """Return an iterator that raises an ExhaustedError at the end."""
          return itertools.chain(iterable, raising_iter(i))
      
      def zip_equal__chain_raising__cjerdonek(*iterables):
          iterators = [terminate_iter(*args) for args in enumerate(iterables)]
          try:
              yield from zip(*iterators)
          except ExhaustedError as exc:
              index = exc.index
              if index > 0:
                  raise RuntimeError('iterable {} exhausted first'.format(index)) from None
              # Check that all other iterators are also exhausted.
              for i, iterator in enumerate(iterators[1:], start=1):
                  try:
                      next(iterator)
                  except ExhaustedError:
                      pass
                  else:
                      raise RuntimeError('iterable {} is longer'.format(i)) from None
                  
      #-----------------------------------------------------------------------------
      # Solution by Stefan Pochmann
      #-----------------------------------------------------------------------------
      
      def zip_equal__ziptail__Stefan_Pochmann(*iterables):
      
          # For trivial cases, use pure zip.
          if len(iterables) < 2:
              return zip(*iterables)
      
          # Tail for the first iterable
          first_stopped = False
          def first_tail():
              nonlocal first_stopped 
              first_stopped = True
              return
              yield
      
          # Tail for the zip
          def zip_tail():
              if not first_stopped:
                  raise ValueError(f'zip_equal: first iterable is longer')
              for _ in chain.from_iterable(rest):
                  raise ValueError(f'zip_equal: first iterable is shorter')
                  yield
      
          # Put the pieces together
          iterables = iter(iterables)
          first = chain(next(iterables), first_tail())
          rest = list(map(iter, iterables))
          return chain(zip(first, *rest), zip_tail())
      
      #-----------------------------------------------------------------------------
      # List of solutions to be speedtested
      #-----------------------------------------------------------------------------
      
      solutions = [
          zip_equal__more_itertools__pylang,
          zip_equal__fillvalue__Martijn_Pieters,
          zip_equal__chain_raising__cjerdonek,
          zip_equal__ziptail__Stefan_Pochmann,
          zip,
      ]
      
      def name(solution):
          return solution.__name__[11:] or 'zip'
      
      #-----------------------------------------------------------------------------
      # The speedtest code
      #-----------------------------------------------------------------------------
      
      def test(m, n):
          """Speedtest all solutions with m iterables of n elements each."""
      
          all_times = {solution: [] for solution in solutions}
          def show_title():
              print(f'{m} iterators of length {n:,}:')
          if verbose: show_title()
          def show_times(times, solution):
              print(*('%3d ms ' % t for t in times),
                    name(solution))
              
          for _ in range(3):
              for solution in solutions:
                  times = sorted(timeit.repeat(lambda: deque(solution(*(repeat(i, n) for i in range(m))), 0), number=1, repeat=5))[:3]
                  times = [round(t * 1e3, 3) for t in times]
                  all_times[solution].append(times)
                  if verbose: show_times(times, solution)
              if verbose: print()
              
          if verbose:
              print('best by min:')
              show_title()
              for solution in solutions:
                  show_times(min(all_times[solution], key=min), solution)
              print('best by max:')
          show_title()
          for solution in solutions:
              show_times(min(all_times[solution], key=max), solution)
          print()
      
          stats.append((m,
                        [min(all_times[solution], key=min)
                         for solution in solutions]))
      
      #-----------------------------------------------------------------------------
      # Run the speedtest for several numbers of iterables
      #-----------------------------------------------------------------------------
      
      stats = []
      verbose = False
      total_elements = 2 * 10**6
      for m in 1, 2, 3, 4, 5, 10, 100, 1000, 10000, 50000, 100000:
          test(m, total_elements // m)
      
      #-----------------------------------------------------------------------------
      # Print the speedtest results for use in the plotting script
      #-----------------------------------------------------------------------------
      
      print('data for plotting by https://replit.com/@pochmann/zipequal-plot')
      names = [name(solution) for solution in solutions]
      print(f'{names = }')
      print(f'{stats = }')
      

      Plotting/table code (also at Replit):

      import matplotlib.pyplot as plt
      
      names = ['more_itertools__pylang', 'fillvalue__Martijn_Pieters', 'chain_raising__cjerdonek', 'ziptail__Stefan_Pochmann', 'zip']
      stats = [(1, [[208.762, 211.211, 214.189], [159.568, 162.233, 162.24], [57.668, 58.94, 59.23], [10.418, 10.583, 10.723], [10.057, 10.443, 10.456]]), (2, [[130.065, 130.26, 130.52], [100.314, 101.206, 101.276], [34.405, 34.998, 35.188], [12.152, 12.473, 12.773], [8.671, 8.857, 9.395]]), (3, [[106.417, 107.452, 107.668], [90.693, 91.154, 91.386], [26.908, 27.863, 28.145], [10.457, 10.461, 10.789], [8.071, 8.157, 8.228]]), (4, [[97.547, 98.686, 98.726], [77.076, 78.31, 79.381], [23.134, 23.176, 23.181], [9.321, 9.4, 9.581], [7.541, 7.554, 7.635]]), (5, [[86.393, 88.046, 88.222], [68.633, 69.649, 69.742], [19.845, 20.006, 20.135], [8.726, 8.935, 9.016], [7.201, 7.26, 7.304]]), (10, [[70.384, 71.762, 72.473], [57.87, 58.149, 58.411], [15.808, 16.252, 16.262], [7.568, 7.57, 7.864], [6.732, 6.888, 6.911]]), (100, [[53.108, 54.245, 54.465], [44.436, 44.601, 45.226], [10.502, 11.073, 11.109], [6.721, 6.733, 6.847], [6.753, 6.774, 6.815]]), (1000, [[52.119, 52.476, 53.341], [42.775, 42.808, 43.649], [12.538, 12.853, 12.862], [6.802, 6.971, 7.002], [6.679, 6.724, 6.838]]), (10000, [[54.802, 55.006, 55.187], [45.981, 46.066, 46.735], [34.416, 34.672, 35.009], [9.485, 9.509, 9.626], [9.036, 9.042, 9.112]]), (50000, [[66.681, 66.98, 67.441], [56.593, 57.341, 57.631], [113.988, 114.022, 114.106], [22.088, 22.412, 22.595], [19.412, 19.431, 19.934]]), (100000, [[86.846, 88.111, 88.258], [74.796, 75.431, 75.927], [218.977, 220.182, 223.343], [38.89, 39.385, 39.88], [32.332, 33.117, 33.594]])]
      
      colors = {
          'more_itertools__pylang': 'm',
          'fillvalue__Martijn_Pieters': 'red',
          'chain_raising__cjerdonek': 'gold',
          'ziptail__Stefan_Pochmann': 'lime',
          'zip': 'black',
      }
      
      ns = [n for n, _ in stats]
      print('%28s' % 'number of iterables', *('%5d' % n for n in ns))
      print('-' * 95)
      x = range(len(ns))
      for i, name in enumerate(names):
          ts = [min(tss[i]) for _, tss in stats]
          color = colors[name]
          if color:
              plt.plot(x, ts, '.-', color=color, label=name)
              print('%29s' % name, *('%5.1f' % t for t in ts))
      plt.xticks(x, ns, size=9)
      plt.ylim(0, 133)
      plt.title('zip_equal(m iterables with 2,000,000/m values each)', weight='bold')
      plt.xlabel('Number of zipped *iterables* (not their lengths)', weight='bold')
      plt.ylabel('Time (for complete iteration) in milliseconds', weight='bold')
      plt.legend(loc='upper center')
      #plt.show()
      plt.savefig('zip_equal_plot.png', dpi=200)
      

      【Comments】:

        【Solution 6】:

        FYI, I came up with a solution that uses a sentinel iterable:

        import itertools
        
        class _SentinelException(Exception):
            def __iter__(self):
                raise _SentinelException
        
        
        def zip_equal(iterable1, iterable2):
            i1 = iter(itertools.chain(iterable1, _SentinelException()))
            i2 = iter(iterable2)
            try:
                while True:
                    yield (next(i1), next(i2))
            except _SentinelException:  # i1 reaches end
                try:
                    next(i2)  # check whether i2 reaches end
                except StopIteration:
                    pass
                else:
                    raise ValueError('the second iterable is longer than the first one')
            except StopIteration: # i2 reaches end, as next(i1) has already been called, i1's length is bigger than i2
                raise ValueError('the first iterable is longer than the second one')
        

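        【Comments】:

        • What advantage does this solution have over the accepted one?
        • It's just an alternative solution. For me, coming from the C++ world, I don't like the "if sentinel in combo" check on every yield. But since we are in the Python world, nobody cares about performance.
        • Thanks for the answer, but if you really care about performance, you should benchmark it. Your solution is 80% slower. Here is a benchmark: gist.github.com/zeehio/cdf7d881cc7f612b2c853fbd3a18ccbe
        • Thanks for the benchmark. Sorry for being misleading. Yes, it is slower; I should have thought of that earlier, since izip_longest is native.
        • Sorry if my reply was harsh. Thanks to your answer we compared performance and found that the accepted answer is faster than the other solutions. We now also have a quick way to benchmark any future solutions. We know more now than we did a week ago. :-)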