【Question title】: *Why* does multiprocessing serialize my function and closure?
【Posted】: 2019-10-17 23:46:55
【Question】:

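According to https://docs.python.org/3/library/multiprocessing.html, multiprocessing forks (on *nix) to create the worker processes that execute tasks. We can verify this by setting a global variable on a module before the fork. If the worker function imports that module and finds the variable present, then the process memory was copied. And so it is: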

import os

def f(x):
    import sys
    return sys._mypid  # <<< value is returned by subprocess!


def set_state():
    import sys
    sys._mypid = os.getpid()

def g():
    from multiprocessing import Pool
    pool = Pool(4)
    try:
        for z in pool.imap(f, range(1000)):
            print(z)
    finally:
        pool.close()
        pool.join()

if __name__=='__main__':
    set_state()
    g()

But if that is how it works, what business does multiprocessing have serializing the worker function, f?

In this example:

import os

def set_state():
    import sys
    sys._mypid = os.getpid()

def g():
    def f(x):
        import sys
        return sys._mypid

    from multiprocessing import Pool
    pool = Pool(4)
    try:
        for z in pool.imap(f, range(1000)):
            print(z)
    finally:
        pool.close()
        pool.join()

if __name__=='__main__':
    set_state()
    g()

we get:

AttributeError: Can't pickle local object 'g.<locals>.f'

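There are plenty of ways to work around this on Stack Overflow and elsewhere on the internet. (Python's standard pickle can handle functions, but only module-level ones, which it stores by name; it cannot handle nested functions or functions with closure data.)

But why do we end up here at all? A copy-on-write version of f sits in the memory of the forked process. Why does it need to be serialized?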
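The pickling limitation mentioned above can be sketched as follows. This is a minimal illustration, not part of the original question: `make_adder` and `can_pickle` are hypothetical helpers, and `json.dumps` merely stands in for any importable module-level function.

```python
import json
import pickle


def make_adder(n):
    # Hypothetical nested function; its qualified name is
    # 'make_adder.<locals>.add', which cannot be imported by name.
    def add(x):
        return x + n
    return add


def can_pickle(obj):
    """Return True if pickle.dumps(obj) succeeds."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        # The exception type for local objects differs between Python versions.
        return False


# A module-level function is pickled as a reference (module + name), not as code.
module_level_ok = can_pickle(json.dumps)
payload = pickle.dumps(json.dumps)

# A nested function has no importable name, so pickling it fails.
closure_ok = can_pickle(make_adder(1))

print(module_level_ok, closure_ok)
```

Because only a name is stored, the receiving process must be able to import the function itself; the function body never travels through the pickle.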

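【Comments】:

  • Never mind, figured it out. It's because f isn't available when Pool is called and the processes are forked.
  • You can write an answer to your own question.
  • @user48956: Note that this is not a matter of timing: f exists when the Pool is constructed, but the Pool has no way of knowing that it will be referenced, so it cannot avoid having to transmit some description of it during imap. One could imagine supplying a "care package" of non-global objects at fork time, but since that would not help the spawn method, it is probably not a priority.
  • That is true for spawn. It isn't true for fork.

Tags: python multiprocessing python-multiprocessing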
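The distinction drawn in these comments can be sketched with a bare `Process` instead of a `Pool`. When the target is known at fork time and the "fork" start method is used, the child inherits the closure through the copied memory image and nothing is pickled except the result. This is a minimal sketch assuming a Unix system; `run_closure_in_fork` and `closure` are hypothetical names.

```python
import multiprocessing as mp


def run_closure_in_fork(offset):
    def closure(conn):
        # Runs in the child. Both `closure` and `offset` are inherited
        # from the parent's memory at fork time; neither is pickled.
        conn.send(offset + 1)
        conn.close()

    ctx = mp.get_context("fork")  # assumption: fork is available (Unix only)
    # duplex=False: the first end only receives, the second only sends.
    parent_conn, child_conn = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=closure, args=(child_conn,))
    proc.start()
    child_conn.close()  # the parent keeps only the receiving end
    value = parent_conn.recv()
    proc.join()
    return value


result = run_closure_in_fork(41)
print(result)
```

A `Pool` cannot do the same, because its workers are forked before any task function is handed over, so each task has to be described to them through a queue.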


【Solution 1】:

Derp -- it has to be this way, because:

    pool = Pool(4)  <<< processes created here

    for z in pool.imap(f, range(1000)):   <<< reference to function

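FYI... anyone who wants to fork in such a way that the new processes have access to the function (and so avoid serializing it) can follow this pattern: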

import collections
import multiprocessing as mp
import os
import pickle
import threading

_STATUS_DATA = 0
_STATUS_ERR = 1
_STATUS_POISON = 2


Message = collections.namedtuple(
    "Message",
    ["status",
     "payload",
     "sequence_id"
     ]
)

def parallel_map(
        target,
        args,
        num_processes,
        inq_maxsize=None,
        outq_maxsize=None,
        serialize=pickle.dumps,
        deserialize=pickle.loads,
        start_method="fork",
        preserve_order=True,
):
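    """
    Map ``target`` over ``args`` using worker processes (or threads).

    :param target: Target function; called with one argument at a time.
    :param args: Iterable of single-parameter arguments for target.
    :param num_processes: Number of worker processes.
    :param inq_maxsize: Maximum size of the input queue (default: 10 * num_processes).
    :param outq_maxsize: Maximum size of the output queue (default: 10 * num_processes).
    :param serialize: Callable that encodes a Message before it is queued.
    :param deserialize: Callable that decodes a queued Message.
    :param start_method: "thread" runs the workers as threads; any other value forks processes.
    :param preserve_order: If true, results are yielded in the order of args. Otherwise,
      results are yielded as they complete.
    :return: Generator yielding the results.
    """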
    if inq_maxsize is None: inq_maxsize=10*num_processes
    if outq_maxsize is None: outq_maxsize=10*num_processes
    inq = mp.Queue(maxsize=inq_maxsize)
    outq = mp.Queue(maxsize=outq_maxsize)
    poison = serialize(Message(_STATUS_POISON, None, -1))
    deserialize(poison) # Test

    def work():
        while True:
            obj = inq.get()
            # print("{} - GET .. OK".format(os.getpid()))
            # inq.task_done()

            try:
                msg = deserialize(obj)
                assert isinstance(msg, Message)
                if msg.status==_STATUS_POISON:
                    outq.put(serialize(Message(_STATUS_POISON,None,msg.sequence_id)))
                    # print("{} - RETURN POISON .. OK".format(os.getpid()))
                    return
                else:
                    args, kw = msg.payload
                    result = target(*args,**kw)
                    outq.put(serialize(Message(_STATUS_DATA,result,msg.sequence_id)))

            except Exception as e:
                try:
                    outq.put(serialize(Message(_STATUS_ERR,e,msg.sequence_id)))
                except Exception as e2:
                    try:
                        outq.put(serialize(Message(_STATUS_ERR,None,-1)))
                        # outq.put(serialize(1,Exception("Unable to serialize response")))
                        # TODO. Log exception
                    except Exception as e3:
                        pass

    if start_method == "thread":
        _start_method = threading.Thread
    else:
        _start_method = mp.get_context('fork').Process

    processes = [
        _start_method(
            target=work,
            name="parallel_map.work"
        )
        for _ in range(num_processes)]

    for p in processes:
        p.start()

    quitting = []
    def quit_processes():
        # Send poison pills (one per worker) exactly once
        if quitting:
            return
        quitting.append(1)
        for _ in range(num_processes):
            inq.put(poison)

    nsent = [0]
    def send():
        # Send the data
        for seq_id, arg in enumerate(args):
            obj = ((arg,), {})
            inq.put(serialize(Message(_STATUS_DATA, obj, seq_id)))
            nsent[0] += 1
        quit_processes()

    # Publish
    sender = threading.Thread(
        target=send,
        name="parallel_map.sender",
        daemon=True)
    sender.start()

    try:
        # Consume
        nquit = [0]
        buffer = {}
        nyielded = 0
        while True:
            result = outq.get() # Waiting here
            # outq.task_done()
            msg = deserialize(result)
            assert isinstance(msg, Message)
            if msg.status == _STATUS_POISON:
                nquit[0]+=1
                # print(">>> QUIT ACK {}".format(nquit[0]))
                if nquit[0]>=num_processes:
                    break
            else:
                assert msg.sequence_id>=0

                if preserve_order:
                    buffer[msg.sequence_id] = msg
                    while True:
                        if nyielded not in buffer:
                            break

                        msg = buffer.pop(nyielded)
                        nyielded += 1
                        if msg.status==_STATUS_ERR:
                            if isinstance(msg.payload, Exception):
                                raise msg.payload
                            else:
                                raise Exception("Unexpected exception")
                        else:
                            assert msg.status==_STATUS_DATA
                            yield msg.payload
                else:
                    if msg.status==_STATUS_ERR:
                        if isinstance(msg.payload, Exception):
                            raise msg.payload
                        else:
                            raise Exception("Unexpected exception")
                    else:
                        assert msg.status==_STATUS_DATA
                        yield msg.payload


                # if nyielded == nsent:
                #     break

    except Exception as e:
        raise
    finally:
        if not quitting:
            quit_processes()
        sender.join()
        for p in processes:
            p.join()



Usage:

import time

def f(x):
    time.sleep(0.01)
    if x == -1:
        raise Exception("Boo")
    return x

for result in parallel_map(target=f,  # f is not serialized
                           args=range(100),
                           num_processes=8,
                           start_method="fork"):
    pass

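... One caveat: when you fork, a puppy dies for every thread in your program. (Only the forking thread survives in the child, so any lock held by another thread at that moment stays locked in the child forever.)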
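The caveat about threads can be observed directly. The sketch below is an illustration under assumptions, not part of the original answer: it uses the Unix-only `os.fork`, and `count_threads_after_fork` is a hypothetical helper. It starts a background thread, forks, and has the child report how many threads it still sees.

```python
import os
import threading


def count_threads_after_fork():
    stop = threading.Event()
    worker = threading.Thread(target=stop.wait, daemon=True)
    worker.start()
    parent_count = threading.active_count()  # main thread + worker (at least)

    read_fd, write_fd = os.pipe()
    pid = os.fork()  # recent Python versions may warn when forking with threads
    if pid == 0:
        # Child: only the thread that called fork() exists here.
        os.close(read_fd)
        os.write(write_fd, str(threading.active_count()).encode())
        os._exit(0)  # leave immediately, skipping the parent's cleanup handlers

    # Parent: read the child's report, then clean up.
    os.close(write_fd)
    child_count = int(os.read(read_fd, 16).decode())
    os.close(read_fd)
    os.waitpid(pid, 0)
    stop.set()
    worker.join()
    return parent_count, child_count


parent_count, child_count = count_threads_after_fork()
print(parent_count, child_count)
```

The worker thread is simply absent in the child, which is why any state it was in the middle of updating (for example, a lock it held) can never be repaired there.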
