【问题标题】:Serializing an object in __main__ with pickle or dill用 pickle 或 dill 序列化 __main__ 中的对象
【发布时间】:2017-08-10 14:31:20
【问题描述】:

我有一个酸洗问题。我想在我的主脚本中序列化一个函数,然后加载它并在另一个脚本中运行它。为了证明这一点,我制作了 2 个脚本:

尝试一:天真的方式:

dill_pickle_script_1.py

import pickle
import time

def my_func(a, b):
    time.sleep(0.1)  # The purpose of this will become evident at the end
    return a+b

if __name__ == '__main__':
    with open('testfile.pkl', 'wb') as f:
        pickle.dump(my_func, f)

dill_pickle_script_2.py

import pickle

if __name__ == '__main__':
    with open('testfile.pkl') as f:
        func = pickle.load(f)
        assert func(1, 2)==3

问题:当我运行脚本 2 时,我得到AttributeError: 'module' object has no attribute 'my_func'。我明白为什么:因为当 my_func 在 script1 中序列化时,它属于 __main__ 模块。 dill_pickle_script_2 不知道那里的__main__ 引用了 dill_pickle_script_1 的命名空间,因此找不到引用。

尝试 2:插入绝对导入

我通过添加一个小技巧解决了这个问题 - 在酸洗之前,我在 dill_pickle_script_1 中添加了一个绝对导入到 my_func。

dill_pickle_script_1.py

import pickle
import time

def my_func(a, b):
    time.sleep(0.1)
    return a+b

if __name__ == '__main__':
    from dill_pickle_script_1 import my_func  # Added absolute import
    with open('testfile.pkl', 'wb') as f:
        pickle.dump(my_func, f)

现在可以了!但是,我想避免每次我想这样做时都必须这样做。 (另外,我想让我的酸洗在其他一些不知道 my_func 来自哪个模块的模块中完成)。

尝试 3:莳萝

我认为 dill 包可以让您在 main 中序列化内容并将它们加载到其他地方。所以我尝试了:

dill_pickle_script_1.py

import dill
import time

def my_func(a, b):
    time.sleep(0.1)
    return a+b

if __name__ == '__main__':
    with open('testfile.pkl', 'wb') as f:
        dill.dump(my_func, f)

dill_pickle_script_2.py

import dill

if __name__ == '__main__':
    with open('testfile.pkl') as f:
        func = dill.load(f)
        assert func(1, 2)==3

然而,现在我遇到了另一个问题:运行dill_pickle_script_2.py 时,我得到了NameError: global name 'time' is not defined。似乎 dill 没有意识到 my_func 引用了 time 模块并且必须在加载时导入它。

我的问题?

我如何在 main 中序列化一个对象,然后在另一个脚本中再次加载它,以便该对象使用的所有导入也被加载,而无需在尝试 2 中进行令人讨厌的小技巧?

【问题讨论】:

  • Conchylicultor's 答案可能就是您要找的。但请记住,您还应该以二进制模式阅读腌制/挖出的文件:open(testfile.pkl, 'rb')

标签: python pickle dill


【解决方案1】:

嗯,我找到了解决办法。这是一个可怕但整洁的组合,不能保证在所有情况下都能正常工作。欢迎任何改进建议。该解决方案包括使用以下辅助函数将主引用替换为 pickle 字符串中的绝对模块引用:

import sys
import os

def pickle_dumps_without_main_refs(obj):
    """
    Yeah this is horrible, but it allows you to pickle an object in the main module so that it can be reloaded in another
    module.
    :param obj:
    :return:
    """
    currently_run_file = sys.argv[0]
    module_path = file_path_to_absolute_module(currently_run_file)
    pickle_str = pickle.dumps(obj, protocol=0)
    pickle_str = pickle_str.replace('__main__', module_path)  # Hack!
    return pickle_str


def pickle_dump_without_main_refs(obj, file_obj):
    string = pickle_dumps_without_main_refs(obj)
    file_obj.write(string)


def file_path_to_absolute_module(file_path):
    """
    Given a file path, return an import path.
    :param file_path: A file path.
    :return:
    """
    assert os.path.exists(file_path)
    file_loc, ext = os.path.splitext(file_path)
    assert ext in ('.py', '.pyc')
    directory, module = os.path.split(file_loc)
    module_path = [module]
    while True:
        if os.path.exists(os.path.join(directory, '__init__.py')):
            directory, package = os.path.split(directory)
            module_path.append(package)
        else:
            break
    path = '.'.join(module_path[::-1])
    return path

现在,我可以简单地将dill_pickle_script_1.py 改为说

import time
from artemis.remote.child_processes import pickle_dump_without_main_refs


def my_func(a, b):
    time.sleep(0.1)
    return a+b

if __name__ == '__main__':
    with open('testfile.pkl', 'wb') as f:
        pickle_dump_without_main_refs(my_func, f)

然后dill_pickle_script_2.py 工作!

【讨论】:

    【解决方案2】:

    您可以将dill.dumprecurse=Truedill.settings["recurse"] = True 一起使用。它将捕获闭包:

    在文件 A 中:

    import time
    import dill
    
    def my_func(a, b):
      time.sleep(0.1)
      return a + b
    
    with open("tmp.pkl", "wb") as f:
      dill.dump(my_func, f, recurse=True)
    

    在文件 B 中:

    import dill
    
    with open("tmp.pkl", "rb") as f:
      my_func = dill.load(f)
    

    【讨论】:

      【解决方案3】:

      这是另一种修改序列化的解决方案,以便在没有任何特殊措施的情况下反序列化。你可能会说它没有Peter's solution 那么老套。

      不是破解pickle.dumps() 的输出,而是子类Pickler 来修改它腌制引用回__main__ 的对象的方式。这确实意味着不能使用快速(C 实现)pickler,因此这种方法会降低性能。它还覆盖了Picklersave_pers() 方法,该方法不打算被覆盖。所以这可能会在 Python 的未来版本中出现问题(虽然不太可能)。

      def get_function_module_str(func):
          """Returns a dotted module string suitable for importlib.import_module() from a
          function reference.
          """
          source_file = Path(inspect.getsourcefile(func))
          # (Doesn't work with built-in functions)
          if not source_file.is_absolute():
              rel_path = source_file
          else:
              # It's an absolute path so find the longest entry in sys.path that shares a
              # common prefix and remove the prefix.
              for path_str in sorted(sys.path, key=len, reverse=True):
                  try:
                      rel_path = source_file.relative_to(Path(path_str))
                      break
                  except ValueError:
                      pass
              else:
                  raise ValueError(f"{source_file!r} is not on the Python path")
          # Replace path separators with dots.
          modules_str = ".".join(p for p in rel_path.with_suffix("").parts if p != "__init__")
          return modules_str, func.__name__
      
      
      class ResolveMainPickler(pickle._Pickler):
          """Subclass of Pickler that replaces __main__ references with the actual module
          name."""
      
          def persistent_id(self, obj):
              """Override to see if this object is defined in "__main__" and if so to replace
              __main__ with the actual module name."""
              if getattr(obj, "__module__", None) == "__main__":
                  module_str, obj_name = get_function_module_str(obj)
                  obj_ref = getattr(importlib.import_module(module_str), obj_name)
                  return obj_ref
              return None
      
          def save_pers(self, pid):
              """Override the function to save a persistent ID so that it saves it as a
              normal reference. So it can be unpickled with no special arrangements.
              """
              self.save(pid, save_persistent_id=False)
      
      
      with io.BytesIO() as pickled:
          pickler = ResolveMainPickler(pickled)
          pickler.dump(obj)
          print(pickled.getvalue())
      

      如果您已经知道 __main__ 模块的名称,那么您可以省去 get_function_module_str() 并直接提供名称。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2019-04-09
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2019-03-26
        • 2015-11-24
        相关资源
        最近更新 更多