【问题标题】:Using torchtext for inference使用 torchtext 进行推理
【发布时间】:2020-10-25 22:11:38
【问题描述】:

我想知道使用torchtext 进行推理的正确方法是什么。

假设我已经训练了模型,并且 dump 所有具有内置词汇表的字段。下一步似乎是使用torchtext.data.Example 加载一个示例。不知何故,我应该使用加载的字段将其数字化并创建一个迭代器。

如果有任何使用torchtext 进行推理的简单示例,我将不胜感激。

【问题讨论】:

    标签: machine-learning nlp pytorch torchtext


    【解决方案1】:

    对于经过训练的模型和词汇(这是文本字段的一部分,您不必保存整个类):

        def read_vocab(path):
            #read vocabulary pkl 
            import pickle
            pkl_file = open(path, 'rb')
            vocab = pickle.load(pkl_file)
            pkl_file.close()
            return vocab
    
    
    
        def load_model_and_vocab():
            import torch
            import os.path
        
            my_path = os.path.abspath(os.path.dirname(__file__))
            vocab_path = os.path.join(my_path, vocab_file)
            weights_path = os.path.join(my_path, WEIGHTS)
        
            vocab = read_vocab(vocab_path)
            model = classifier(vocab_size=len(vocab))
            model.load_state_dict(torch.load(weights_path))
            model.eval()
            return model, vocab
        
        
        def predict(model, vocab, sentence):
            tokenized = [w.text.lower() for w in nlp(sentence)]  # tokenize the sentence
            indexed = [vocab.stoi[t] for t in tokenized]         # convert to integer sequence
            length = [len(indexed)]                              # compute no. of words
            tensor = torch.LongTensor(indexed).to('cpu')         # convert to tensor
            tensor = tensor.unsqueeze(1).T                       # reshape in form of batch,no. of words
            length_tensor = torch.LongTensor(length)             # convert to tensor
            prediction = model(tensor, length_tensor)            # prediction
            return round(1-prediction.item())
    

    “分类器”是我为我的模型定义的类。

    用于保存词汇 pkl:

        def save_vocab(vocab):
            import pickle
            output = open('vocab.pkl', 'wb')
            pickle.dump(vocab, output)
            output.close()
    

    为了在训练后保存模型,您可以使用:

        torch.save(model.state_dict(), 'saved_weights.pt')
    

    告诉我它是否适合你!

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-04-20
      • 2023-03-25
      • 1970-01-01
      • 1970-01-01
      • 2020-10-25
      • 2021-09-11
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多