【问题标题】:deep copy nested iterable (or improved itertools.tee for iterable of iterables)深拷贝嵌套可迭代(或改进的 itertools.tee 用于可迭代的迭代)
【发布时间】:2019-05-16 01:13:54
【问题描述】:

前言

我有一个测试,我正在使用嵌套的可迭代对象(嵌套的可迭代对象我的意思是只有可迭代对象作为元素的可迭代对象)。

作为测试级联考虑

from itertools import tee
from typing import (Any,
                    Iterable)


def foo(nested_iterable: Iterable[Iterable[Any]]) -> Any:
    ...


def test_foo(nested_iterable: Iterable[Iterable[Any]]) -> None:
    original, target = tee(nested_iterable)  # this doesn't copy iterators elements

    result = foo(target)

    assert is_contract_satisfied(result, original)


def is_contract_satisfied(result: Any,
                          original: Iterable[Iterable[Any]]) -> bool:
    ...

例如foo 可能是简单的身份函数

def foo(nested_iterable: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
    return nested_iterable

合同只是检查扁平化的可迭代对象是否具有相同的元素

from itertools import (chain,
                       starmap,
                       zip_longest)
from operator import eq
...
flatten = chain.from_iterable


def is_contract_satisfied(result: Iterable[Iterable[Any]],
                          original: Iterable[Iterable[Any]]) -> bool:
    return all(starmap(eq,
                       zip_longest(flatten(result), flatten(original),
                                   # we're assuming that ``object()``
                                   # will create some unique object
                                   # not presented in any of arguments
                                   fillvalue=object())))

但如果nested_iterable 的某些元素是一个迭代器,它可能会耗尽,因为tee 正在制作浅拷贝,而不是深拷贝,即对于给定的foois_contract_satisfied 下一条语句

>>> test_foo([iter(range(10))])

导致可预测的

Traceback (most recent call last):
  ...
    test_foo([iter(range(10))])
  File "...", line 19, in test_foo
    assert is_contract_satisfied(result, original)
AssertionError

问题

如何深度复制任意嵌套的可迭代对象?

注意

我知道copy.deepcopy function,但它不适用于文件对象。

【问题讨论】:

  • 您是否有任何理由反对简单地将嵌套迭代器具体化为嵌套列表?
  • @juanpa.arrivillaga: 是的,我正在编写一个可与​​任意迭代(有限和无限、用户定义或来自标准库)一起使用的库,并编写基于属性的测试

标签: python itertools iterable


【解决方案1】:

朴素的解决方案

简单的算法是

  1. 对原始嵌套迭代执行元素复制。
  2. 制作 n 元素副本的副本。
  3. 获取与每个独立副本相关的坐标。

可以像这样实现

from itertools import tee
from operator import itemgetter
from typing import (Any,
                    Iterable,
                    Tuple,
                    TypeVar)

Domain = TypeVar('Domain')


def copy_nested_iterable(nested_iterable: Iterable[Iterable[Domain]],
                         *,
                         count: int = 2
                         ) -> Tuple[Iterable[Iterable[Domain]], ...]:
    def shallow_copy(iterable: Iterable[Domain]) -> Tuple[Iterable[Domain], ...]:
        return tee(iterable, count)

    copies = shallow_copy(map(shallow_copy, nested_iterable))
    return tuple(map(itemgetter(index), iterables)
                 for index, iterables in enumerate(copies))

优点:

  • 很容易阅读和解释。

缺点:

  • 如果我们想将我们的方法扩展到具有更高嵌套级别的可迭代对象(例如嵌套可迭代对象的可迭代对象等),这种方法看起来没有帮助。

我们可以做得更好。

改进的解决方案

如果我们查看 itertools.tee function documentation,它包含 Python 配方,在 functools.singledispatch decorator 的帮助下可以重写为

from collections import (abc,
                         deque)
from functools import singledispatch
from itertools import repeat
from typing import (Iterable,
                    Tuple,
                    TypeVar)

Domain = TypeVar('Domain')


@functools.singledispatch
def copy(object_: Domain,
         *,
         count: int) -> Iterable[Domain]:
    raise TypeError('Unsupported object type: {type}.'
                    .format(type=type(object_)))

# handle general case
@copy.register(object)
# immutable strings represent a special kind of iterables
# that can be copied by simply repeating
@copy.register(bytes)
@copy.register(str)
# mappings cannot be copied as other iterables
# since they are iterable only by key
@copy.register(abc.Mapping)
def copy_object(object_: Domain,
                *,
                count: int) -> Iterable[Domain]:
    return itertools.repeat(object_, count)


@copy.register(abc.Iterable)
def copy_iterable(object_: Iterable[Domain],
                  *,
                  count: int = 2) -> Tuple[Iterable[Domain], ...]:
    iterator = iter(object_)
    # we are using `itertools.repeat` instead of `range` here
    # due to efficiency of the former
    # more info at
    # https://stackoverflow.com/questions/9059173/what-is-the-purpose-in-pythons-itertools-repeat/9098860#9098860
    queues = [deque() for _ in repeat(None, count)]

    def replica(queue: deque) -> Iterable[Domain]:
        while True:
            if not queue:
                try:
                    element = next(iterator)
                except StopIteration:
                    return
                element_copies = copy(element,
                                           count=count)
                for sub_queue, element_copy in zip(queues, element_copies):
                    sub_queue.append(element_copy)
            yield queue.popleft()

    return tuple(replica(queue) for queue in queues)

优点:

  • 处理更深层次的嵌套,甚至处理混合元素,如同一级别上的可迭代和不可迭代,
  • 可以针对用户定义的结构进行扩展(例如,用于制作它们的独立深层副本)。

缺点:

  • 可读性较差(但我们知道"practicality beats purity"),
  • 提供了一些与调度相关的开销(但没关系,因为它基于具有O(1) 复杂性的字典查找)。

测试

准备

让我们定义我们的嵌套迭代如下

nested_iterable = [range(10 ** index) for index in range(1, 7)]

由于迭代器的创建没有说明底层副本的性能,让我们定义迭代器耗尽的函数(描述为here

exhaust_iterable = deque(maxlen=0).extend

时间

使用timeit

import timeit

def naive(): exhaust_iterable(copy_nested_iterable(nested_iterable))

def improved(): exhaust_iterable(copy_iterable(nested_iterable))

print('naive approach:', min(timeit.repeat(naive)))
print('improved approach:', min(timeit.repeat(improved)))

我的笔记本电脑上有 Python 3.5.4 中的 Windows 10 x64

naive approach: 5.1863865
improved approach: 3.5602296000000013

内存

使用memory_profiler package

Line #    Mem usage    Increment   Line Contents
================================================
    78     17.2 MiB     17.2 MiB   @profile
    79                             def profile_memory(nested_iterable: Iterable[Iterable[Any]]) -> None:
    80     68.6 MiB     51.4 MiB       result = list(flatten(flatten(copy_nested_iterable(nested_iterable))))

对于“幼稚”的方法和

Line #    Mem usage    Increment   Line Contents
================================================
    78     17.2 MiB     17.2 MiB   @profile
    79                             def profile_memory(nested_iterable: Iterable[Iterable[Any]]) -> None:
    80     68.7 MiB     51.4 MiB       result = list(flatten(flatten(copy_iterable(nested_iterable))))

对于“改进”的一个。

注意:我制作了不同的脚本运行,因为一次运行它们不会具有代表性,因为第二个语句将重用之前创建的底层 int 对象。


结论

我们可以看到这两个函数具有相似的性能,但最后一个支持更深层次的嵌套并且看起来相当可扩展。

广告

我已经从0.4.0 版本向lz package 添加了“改进的”解决方案,可以像这样使用

>>> from lz.replication import replicate
>>> iterable = iter(range(5))
>>> list(map(list, replicate(iterable,
                             count=3)))
[[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]

它使用hypothesis framework 进行了基于属性的测试,因此我们可以确定它按预期工作。

【讨论】:

    【解决方案2】:

    解决您的问题:如何深度复制嵌套的可迭代对象?

    您可以使用标准库中的deepcopy

    >>> from copy import deepcopy
    >>> 
    >>> ni = [1, [2,3,4]]
    >>> ci = deepcopy(ni)
    >>> ci[1][0] = "Modified"
    >>> ci
    [1, ['Modified', 3, 4]]
    >>> ni
    [1, [2,3,4]]
    

    更新

    @Azat Ibrakov 说:你正在处理序列,例如尝试对文件对象进行深度复制(提示:它会失败)

    不,对文件对象进行深拷贝,不会失败,可以深拷贝文件对象,演示:

    import copy
    
    with open('example.txt', 'w') as f:
         f.writelines(["{}\n".format(i) for i in range(100)])
    
    with open('example.txt', 'r') as f:
        l = [1, [f]]
        c = copy.deepcopy(l)
        print(isinstance(c[1][0], file))  # Prints  True.
        print("\n".join(dir(c[1][0])))
    

    打印:

    True
    __class__
    __delattr__
    __doc__
    __enter__
    __exit__
    __format__
    __getattribute__
    ...
    write
    writelines
    xreadlines
    

    问题出在概念上。

    根据 Python Iterator 协议,某些容器包含的项目是通过执行next 函数获得的,参见docs here

    在遍历整个迭代器之前,您不会拥有实现迭代器协议的对象的所有项(作为文件对象)(执行 next() 直到引发 StopIteration 异常)。

    这是因为您无法确定执行迭代器的next(Python 2.x 为__next__)方法的结果

    请看下面的例子:

    import random
    
    class RandomNumberIterator:
    
        def __init__(self):
            self.count = 0
            self.internal_it = range(10)  # For later demostration on deepcopy
    
        def __iter__(self):
            return self
    
        def next(self):
            self.count += 1
            if self.count == 10:
                raise StopIteration
            return random.randint(0, 1000)
    
    ri = RandomNumberIterator()
    
    for i in ri:
        print(i)  # This will print randor numbers each time.
                  # Can you come out with some sort of mechanism to be able
                  # to copy **THE CONTENT** of the `ri` iterator? 
    

    你也可以:

    from copy import deepcopy
    
    cri = deepcopy(ri)
    
    for i in cri.internal_it:
        print(i)   # Will print numbers 0..9
                   # Deepcopy on ri successful!
    

    文件对象在这里是一个特例,涉及到文件处理程序,之前,你可以看到你可以深度复制一个文件对象,但它会有closed状态。

    替代方案。

    你可以在你的迭代器上调用list,它会自动评估迭代器,然后你就可以再次测试迭代器的内容

    返回文件:

    with open('example.txt', 'w') as f:
             f.writelines(["{}\n".format(i) for i in range(5)])
    
    with open('example.txt', 'r') as f:
        print(list(f))  # Prints ['0\n', '1\n', '2\n', '3\n', '4\n']
    

    所以,继续

    您可以对嵌套的可迭代对象进行深度复制,但是,您无法在复制可迭代对象时对其进行评估,这毫无意义(请记住 RandomNumberIterator)。

    如果您需要测试可迭代对象CONTENT,您需要评估它们。

    【讨论】:

    • 您正在处理序列,例如尝试对文件对象进行深度复制(提示:它将失败)
    • 您使用的是哪个 Python 版本?对于 Python 3,deepcopying 文件对象将以 TypeError: cannot serialize '_io.TextIOWrapper' object 结尾
    • “你不能在被复制时评估迭代”是什么意思,我可以成功地使用itertools.tee 复制普通迭代,然后独立评估它们中的每一个,甚至可能无限的
    • 在 Python 2.7 中,如果我这样做 copy.deepcopy(file),我会得到 <closed file '<uninitialized file>', mode '<uninitialized file>' at 0x7f4a99ac3930>,并且在尝试像 list(file_copy) 那样对其进行迭代之后,它会引发 ValueError: I/O operation on closed file,而原始版本按预期工作,所以不,你无法使用 copy.deepcopy 函数制作文件对象的功能副本
    • @Azat Ibrakov ,在您的代码示例中,您使用的是类型注释,因此,我假设您使用的是 Python 3.x ,因此,我也使用 Python 3x 作为我的答案。另一方面,文件对象,与文件内容不一样。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-08-24
    • 2015-07-23
    • 2016-12-25
    • 1970-01-01
    相关资源
    最近更新 更多