【问题标题】:Loading pool layer of simple transformer简单变压器的加载池层
【发布时间】:2022-01-11 16:19:27
【问题描述】:

我有一个经过微调的简单变压器表示模型。现在我想仅以 pickle 格式保存池层的权重,并将其放入我正在设计的另一个自定义自动编码器的池层中。如何使用 pytorch 和 python 来做到这一点?

【问题讨论】:

    标签: python nlp pytorch bert-language-model simpletransformers


    【解决方案1】:

    每个 PyTorch 模块旁边都有一个名为 state_dict 的对象,它允许将任何参数映射到其对应的张量变量(更多关于此 here)。使用此实用程序,您可以轻松保存和加载参数,但请记住,您必须事先确定您想要在语义上(从机器学习的角度)和语法上(形状兼容性和......)做什么!下面的实现将使用我们之前保存的模型中的相应变量替换名称中带有单词pooling 的任何参数。

    finetuned_model = BertLMHeadModel.from_pretrained('bert-base-cased')
    torch.save(finetuned_model.state_dict(), "finetuned_model.pth")
    finetuned_model_state_dict = torch.load("finetuned_model.pth")
    new_model = BertLMHeadModel.from_pretrained('bert-base-cased')
    new_model_state_dict = new_model.state_dict()
    for key, value in new_model_state_dict.items():
      if key.find('pooling')!=-1:
        new_model_state_dict.update({key: value})
    

    【讨论】:

    • @Parmida Granfar 的答案有帮助吗?
    猜你喜欢
    • 2016-03-19
    • 2021-08-03
    • 2021-10-02
    • 1970-01-01
    • 1970-01-01
    • 2021-05-10
    • 2021-12-23
    • 1970-01-01
    • 2023-01-21
    相关资源
    最近更新 更多