【问题标题】:how to save torchtext Dataset?如何保存torchtext数据集?
【发布时间】:2019-04-24 14:36:43
【问题描述】:

我正在处理文本并使用torchtext.data.Dataset。 创建数据集需要大量时间。 对于仅运行程序,这仍然是可以接受的。但我想调试神经网络的火炬代码。如果 python 在调试模式下启动,数据集创建大约需要 20 分钟 (!!)。这只是为了获得一个可以调试神经网络代码的工作环境。

我想保存数据集,例如使用 pickle。此示例代码取自 here,但我删除了此示例不需要的所有内容:

from torchtext import data
from fastai.nlp import *

PATH = 'data/aclImdb/'

TRN_PATH = 'train/all/'
VAL_PATH = 'test/all/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'

TEXT = data.Field(lower=True, tokenize="spacy")

bs = 64;
bptt = 70

FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)
md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)

with open("md.pkl", "wb") as file:
    pickle.dump(md, file)

要运行代码,您需要 aclImdb 数据集,它可以从here 下载。将其解压缩到此代码 sn-p 旁边的 data/ 文件夹中。代码在最后一行产生错误,其中使用了 pickle:

Traceback (most recent call last):
  File "/home/lhk/programming/fastai_sandbox/lesson4-imdb2.py", line 27, in <module>
    pickle.dump(md, file)
TypeError: 'generator' object is not callable

fastai 的样本经常使用dill 而不是pickle。但这对我也不起作用。

【问题讨论】:

    标签: python pickle pytorch torch torchtext


    【解决方案1】:

    如果您的数据集很小,pickle/dill 方法很好。但如果您使用的是大型数据集,我不会推荐它,因为它太慢了。

    我只是将示例(迭代地)保存为 JSON 字符串。这背后的原因是因为保存整个 Dataset 对象需要花费大量时间,而且您需要像 dill 这样的序列化技巧,这使得序列化变得更慢。

    此外,这些序列化程序会占用大量内存(其中一些甚至会创建数据集的副本),如果它们开始使用交换内存,那么您就完成了。该过程需要很长时间,您可能会在它完成之前终止它。

    因此,我最终采用了以下方法:

    1. 迭代示例
    2. 将每个示例(即时)转换为 JSON 字符串
    3. 将该 JSON 字符串写入文本文件(每个示例一个 行)
    4. 加载时,将示例与字段一起添加到 Dataset 对象中
    def save_examples(dataset, savepath):
        with open(savepath, 'w') as f:
            # Save num. elements (not really need it)
            f.write(json.dumps(total))  # Write examples length
            f.write("\n")
    
            # Save elements
            for pair in dataset.examples:
                data = [pair.src, pair.trg]
                f.write(json.dumps(data))  # Write samples
                f.write("\n")
    
    
    def load_examples(filename):
        examples = []
        with open(filename, 'r') as f:
            # Read num. elements (not really need it)
            total = json.loads(f.readline())
    
            # Save elements
            for i in range(total):
                line = f.readline()
                example = json.loads(line)
                # example = data.Example().fromlist(example, fields)  # Create Example obj. (you can do it here or later)
                examples.append(example)
    
        end = time.time()
        print(end - start)
        return examples
    

    然后,您可以通过以下方式简单地重建数据集:

    # Define fields
    SRC = data.Field(...)
    TRG = data.Field(...)
    fields = [('src', SRC), ('trg', TRG)]
    
    # Load examples from JSON and convert them to "Example objects"
    examples = load_examples(filename)
    examples = [data.Example().fromlist(d, fields) for d in examples]
    
    # Build dataset
    mydataset = Dataset(examples, fields)
    

    我使用 JSON 代替 pickle、dill、msgpack 等的原因并非随意。

    我做了一些测试,结果如下:

    Dataset size: 2x (1,960,641)
    
    Saving times:
    - Pickle/Dill*: >30-45 min (...or froze my computer)
    
    - MessagePack (iterative): 123.44 sec
      100%|██████████| 1960641/1960641 [02:03<00:00, 15906.52it/s]
    
    - JSON (iterative): 16.33 sec
      100%|██████████| 1960641/1960641 [00:15<00:00, 125955.90it/s]
    
    - JSON (bulk): 46.54 sec (memory problems)
    
    Loading times:
     - Pickle/Dill*: -
    
     - MessagePack (iterative): 143.79 sec
       100%|██████████| 1960641/1960641 [02:23<00:00, 13635.20it/s]
    
     - JSON (iterative): 33.83 sec
       100%|██████████| 1960641/1960641 [00:33<00:00, 57956.28it/s] 
    
     - JSON (bulk): 27.43 sec
    

    *与其他答案类似的方法

    【讨论】:

      【解决方案2】:

      我自己想出了以下函数:

      import dill
      from pathlib import Path
      
      import torch
      from torchtext.data import Dataset
      
      def save_dataset(dataset, path):
          if not isinstance(path, Path):
              path = Path(path)
          path.mkdir(parents=True, exist_ok=True)
          torch.save(dataset.examples, path/"examples.pkl", pickle_module=dill)
          torch.save(dataset.fields, path/"fields.pkl", pickle_module=dill)
      
      def load_dataset(path):
          if not isinstance(path, Path):
              path = Path(path)
          examples = torch.load(path/"examples.pkl", pickle_module=dill)
          fields = torch.load(path/"fields.pkl", pickle_module=dill)
          return Dataset(examples, fields)
      

      并不是说实际的对象可能有点不同,例如,如果你保存TabularDataset,那么load_dataset 返回一个类Dataset 的实例。这不太可能影响数据管道,但可能需要额外的测试努力。 在自定义标记器的情况下,它也应该是可序列化的(例如,没有 lambda 函数等)。

      【讨论】:

        【解决方案3】:

        您总是可以使用 pickle 转储对象,但请记住,转储字典或字段对象列表不受模块处理,因此最好先尝试分解列表

        将 DataSet 对象存储到 pickle 文件中以便以后轻松加载

        def save_to_pickle(dataSetObject,PATH):
            with open(PATH,'wb') as output:
                for i in dataSetObject:
                    pickle.dump(vars(i), output, pickle.HIGHEST_PROTOCOL)
        

        最艰难的事情还没有到来,是的,正在加载 pickle 文件.... ;)

        首先,尝试查找所有字段名称和字段属性,然后进行kill

        将pickle文件加载到DataSetObject中

        def load_pickle(PATH, FIELDNAMES, FIELD):
            dataList = []
            with open(PATH, "rb") as input_file:
                while True:
                    try:
                        # Taking the dictionary instance as the input Instance
                        inputInstance = pickle.load(input_file)
                        # plugging it into the list
                        dataInstance =  [inputInstance[FIELDNAMES[0]],inputInstance[FIELDNAMES[1]]]
                        # Finally creating an example objects list
                        dataList.append(Example().fromlist(dataInstance,fields=FIELD))
                    except EOFError:
                        break
        
            # At last creating a data Set Object
            exampleListObject = Dataset(dataList, fields=data_fields)
            return exampleListObject 
        

        这个骇人听闻的解决方案在我的案例中很有效,希望您会发现它对您的案例也很有用。

        顺便说一句,欢迎提出任何建议:)。

        【讨论】:

          【解决方案4】:

          你可以用莳萝代替泡菜。这个对我有用。 您可以保存一个 torchtext 字段,如

          TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True,fix_length=200,batch_first=True)
          with open("model/TEXT.Field","wb")as f:
               dill.dump(TEXT,f)
          

          并加载类似的字段

          with open("model/TEXT.Field","rb")as f:
               TEXT=dill.load(f)
          

          官方代码支持正在开发中,您可以关注https://github.com/pytorch/text/issues/451https://github.com/pytorch/text/issues/73

          【讨论】:

            猜你喜欢
            • 2022-07-14
            • 2022-10-15
            • 1970-01-01
            • 2022-10-24
            • 2021-09-24
            • 1970-01-01
            • 2022-06-13
            • 2017-10-27
            • 1970-01-01
            相关资源
            最近更新 更多