【Question title】: AttributeError: Can't get attribute 'my_func' on <module '__main__' from 'main.py'>
【Posted】: 2022-12-07 19:02:16
【Question】:

I want to create a Python script from my notebook, in order to get runtimes using the same .pkl file.


On this line:

learn = load_learner('model.pkl', cpu=True)

I get this error:

(project) daniel@ubuntu-pcs:~/PycharmProjects/project$ python main.py 
Traceback (most recent call last):
  File "main.py", line 6, in <module>
    from src.train.train_model import train
  File "/home/daniel/PycharmProjects/project/src/train/train_model.py", line 17, in <module>
    learn = load_learner('yasmine-sftp/export_2.pkl', cpu=True)  # to run on GPU
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/site-packages/fastai/learner.py", line 384, in load_learner
    res = torch.load(fname, map_location='cpu' if cpu else None, pickle_module=pickle_module)
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/site-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/site-packages/torch/serialization.py", line 882, in _load
    result = unpickler.load()
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/site-packages/torch/serialization.py", line 875, in find_class
    return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'Tf' on <module '__main__' from 'main.py'>

This is because, in order to open the .pkl file, I need the original function that was used to train it.

Thankfully, looking back at the notebook, Tf(o) is right there:

def Tf(o):
    return '/mnt/scratch2/DLinTHDP/PathLAKE/Version_4_fastai/Dataset/CD8/Train/masks/'+f'{o.stem}_P{o.suffix}'

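However, no matter where I put Tf(o) in the Python script, I still get the same error.

Where should I put Tf(o)?

The <module '__main__' from 'main.py'> part of the error message seems to suggest putting it inside main() or under the if __name__ ... block.

I've tried it everywhere. Importing Tf(o) doesn't work either.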


Python scripts

main.py:

import glob
from pathlib import Path

from train_model import train

ROOT = Path("folder/path")  # Detection Folder


def main(root: Path):
    train(root)


if __name__ == '__main__':
    main(ROOT)

train_model.py:

from pathlib import Path

from fastai.vision.all import *


folder_path = Path('.')

learn = load_learner('model.pkl', cpu=True)  # AttributeError
learn.load('model_3C_34_CELW_V_1.1')  # weights


def train(root: Path):
    # ... (rest of the training code omitted)
    ...  # placeholder body so the snippet parses

I can't inspect the file either:

(project) daniel@ubuntu-pcs:~/PycharmProjects/project$ python -m pickletools -a model.pkl
Traceback (most recent call last):
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/pickletools.py", line 2830, in <module>
    args.indentlevel, annotate)
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/pickletools.py", line 2394, in dis
    for opcode, arg, pos in genops(pickle):
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/pickletools.py", line 2242, in _genops
    arg = opcode.arg.reader(data)
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/pickletools.py", line 373, in read_stringnl_noescape
    return read_stringnl(f, stripquotes=False)
  File "/home/daniel/miniconda3/envs/project/lib/python3.6/pickletools.py", line 359, in read_stringnl
    data = codecs.escape_decode(data)[0].decode("ascii")
UnicodeDecodeError: 'ascii' codec can't decode byte 0x80 in position 63: ordinal not in range(128)
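A likely reason pickletools fails here: the `opened_zipfile` frame in the first traceback suggests model.pkl was written with `torch.save`'s zip-based format, so the file is a zip archive rather than a bare pickle stream. The sketch below opens the archive, finds the `data.pkl` member, and lists the module/name pairs the pickle refers to. `find_pickle_globals` is a hypothetical helper, not a fastai or PyTorch API. It assumes protocol 2 (the default `pickle_protocol` of `torch.save`), where such references are `GLOBAL` opcodes; protocol 4+ pickles use `STACK_GLOBAL` and are not handled.

```python
import pickletools
import zipfile


def find_pickle_globals(path):
    """Return (module, name) pairs referenced by the data.pkl inside a zip checkpoint.

    Hypothetical helper. Assumes a torch.save zip archive pickled with
    protocol <= 3, where every reference is encoded as a GLOBAL opcode.
    """
    with zipfile.ZipFile(path) as zf:
        # The member is stored as '<archive name>/data.pkl'; the prefix varies.
        member = next(n for n in zf.namelist() if n.endswith("data.pkl"))
        payload = zf.read(member)

    refs = []
    # genops only decodes opcodes; it never imports or calls anything.
    for opcode, arg, _pos in pickletools.genops(payload):
        if opcode.name == "GLOBAL":
            module, name = arg.split(" ", 1)  # arg looks like "module name"
            refs.append((module, name))
    return refs
```

For this question, filtering the result with `[r for r in find_pickle_globals("model.pkl") if r[0] == "__main__"]` would be expected to show which names (such as `Tf`) the unpickler will look up in `__main__`.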

Tags: python-3.x jupyter-notebook pycharm attributeerror fast-ai


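【Solution 1】:

Problem

I got this error because model.pkl was trained using the Tf() function defined in the same namespace as the training code (the training was done in a notebook).

This article states:

pickle is lazy and does not serialize class definitions or function definitions. Instead, it saves a reference of how to find the class (the module it lives in and its name).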
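To see what the quoted passage means in practice, here is a small standard-library sketch (the `Tf` body is a simplified stand-in for the notebook's path builder). Run as a script, the pickle holds only the reference `__main__ Tf`, and once that name is missing from `__main__`, loading fails with the same kind of `AttributeError` as in the question.

```python
import pickle
import pickletools


def Tf(o):
    # Simplified stand-in for the notebook helper (which builds a mask path).
    return f"{o}_P"


payload = pickle.dumps(Tf, protocol=2)
pickletools.dis(payload)  # the only payload is a GLOBAL opcode naming the function; no code is stored

del Tf  # simulate a fresh process (e.g. main.py) in which Tf was never defined

load_failed = False
try:
    pickle.loads(payload)
except AttributeError as err:
    load_failed = True
    print(err)  # e.g. Can't get attribute 'Tf' on <module '__main__' ...>
```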

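Solution

There is an extension of pickle called dill, which can serialize Python objects, functions, etc. by value rather than as references. It is available on PyPI.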


【Solution 2】:

Following ecatkins (Edward Atkins), I solved a similar problem by modifying load_learner from fastai.basic_train.py (fastai==1.0.61): https://forums.fast.ai/t/error-loading-saved-model-with-custom-loss-function/37627/7

load_learner fails when a class used to define the model cannot be found in its original module. So the pickle class loader (find_class) is overridden to search a specified module instead.

import imp, sys
# Note: `imp` is deprecated and was removed in Python 3.12; this targets the older Python used with fastai 1.x.
# Load a separate copy of the pickle module so that patching its Unpickler leaves the global `pickle` untouched.
pickle2 = imp.load_module('pickle2', *imp.find_module('pickle'))

# The module where the class is now found (placeholder: replace with your module).
MODULE = "MY.MODULE"

class CustomUnpickler(pickle2.Unpickler):

    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except AttributeError:
            if module == "__main__":
                print(f"load_learner can't find {name} in original module {module}; getting it from {MODULE}")
                return super().find_class(MODULE, name)
            raise  # any other failure is a real error; don't silently return None


# Modified load_learner from fastai.basic_train.py (fastai==1.0.61), according to ecatkins Edward Atkins
# https://forums.fast.ai/t/error-loading-saved-model-with-custom-loss-function/37627/7
# Assumes the names used below (PathOrStr, ItemList, LabelLists, defaults, load_callback, torch, ...)
# are in scope, as they are inside fastai/basic_train.py.
def load_learner2(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', test:ItemList=None, tfm_y=None, **db_kwargs):
    "Load a `Learner` object saved with `export_state` in `path/file` with empty data, optionally add `test` and load on `cpu`. `file` can be file-like (file or buffer)"
    source = Path(path)/file if is_pathlike(file) else file
    # state = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
    # Use custom class loader here
    pickle2.Unpickler = CustomUnpickler
    state = torch.load(source, map_location='cpu',  pickle_module=pickle2) if defaults.device == torch.device('cpu') else torch.load(source,  pickle_module=pickle2)
    model = state.pop('model')
    src = LabelLists.load_state(path, state.pop('data'))
    if test is not None: src.add_test(test, tfm_y=tfm_y)
    data = src.databunch(**db_kwargs)
    cb_state = state.pop('cb_state')
    clas_func = state.pop('cls')
    res = clas_func(data, model, **state)
    res.callback_fns = state['callback_fns'] #to avoid duplicates
    res.callbacks = [load_callback(c,s, res) for c,s in cb_state.items()]
    return res
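On current Python versions the `imp` module is gone, but the same idea can be sketched with the standard library alone: subclass `pickle.Unpickler`, override `find_class`, and retry a failed lookup in another module. In this sketch, `FallbackUnpickler`, `loads_with_fallback`, and the module name `my_helpers` are hypothetical placeholders, not fastai or PyTorch APIs. Failures in other modules are re-raised rather than turned into `None`.

```python
import importlib
import io
import pickle


class FallbackUnpickler(pickle.Unpickler):
    """Unpickler that retries a failed attribute lookup in another module.

    source_module: module name recorded in the pickle (e.g. "__main__").
    fallback_module: importable module that now defines the missing names
    ("my_helpers" is a hypothetical placeholder).
    """

    def __init__(self, file, source_module="__main__", fallback_module="my_helpers"):
        super().__init__(file)
        self.source_module = source_module
        self.fallback_module = fallback_module

    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except AttributeError:
            if module != self.source_module:
                raise  # unrelated failure: surface it instead of returning None
            fallback = importlib.import_module(self.fallback_module)
            return getattr(fallback, name)


def loads_with_fallback(data, **kwargs):
    """Unpickle bytes, redirecting names that are missing from source_module."""
    return FallbackUnpickler(io.BytesIO(data), **kwargs).load()
```

Because `torch.load` builds its own unpickler as a subclass of `pickle_module.Unpickler` (the `find_class` frame in the question's traceback calls `super().find_class`), using this with `load_learner` would still need a module object exposing the subclass as `Unpickler`, as the answer's `pickle2` copy does.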
      


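【Solution 3】:

Another approach is to manually replicate the namespace of the training environment.

In my case:

# Must run in the __main__ module before load_learner unpickles the model.
# (`global` is a no-op at module level; assigning the name is what matters.)
global Tf
Tf = None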
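Where the definition has to go matters here. According to the traceback, `load_learner` runs while train_model.py is being imported (triggered by the `from ... import train` on line 6 of main.py), so a `Tf` defined further down in main.py does not exist yet when unpickling happens; in main.py it would have to come before that import. One way to replicate the notebook namespace, sketched below, is to attach the helper to the `__main__` module from train_model.py just before `load_learner` is called. `register_in_main` is a hypothetical helper name. Setting `Tf = None`, as above, also gets past unpickling, but the restored learner then holds `None` in place of the function, so anything that later calls it would fail.

```python
import __main__


def register_in_main(*objs):
    """Hypothetical helper: expose objects as attributes of the __main__ module.

    A pickle that recorded a reference such as __main__.Tf looks the name up
    on sys.modules["__main__"], so it must exist there before unpickling starts.
    """
    for obj in objs:
        setattr(__main__, obj.__name__, obj)


# In train_model.py this would sit above the load_learner(...) call.
def Tf(o):
    # Simplified stand-in: copy the notebook's actual Tf body here.
    return f"{o}_P"


register_in_main(Tf)
```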
        
