【发布时间】:2022-01-11 16:19:27
【问题描述】:
我有一个经过微调的简单变压器表示模型。现在我想仅以 pickle 格式保存池层的权重,并将其放入我正在设计的另一个自定义自动编码器的池层中。如何使用 pytorch 和 python 来做到这一点?
【问题讨论】:
标签: python nlp pytorch bert-language-model simpletransformers
我有一个经过微调的简单变压器表示模型。现在我想仅以 pickle 格式保存池层的权重,并将其放入我正在设计的另一个自定义自动编码器的池层中。如何使用 pytorch 和 python 来做到这一点?
【问题讨论】:
标签: python nlp pytorch bert-language-model simpletransformers
每个 PyTorch 模块旁边都有一个名为 state_dict 的对象,它允许将任何参数映射到其对应的张量变量(更多关于此 here)。使用此实用程序,您可以轻松保存和加载参数,但请记住,您必须事先确定您想要在语义上(从机器学习的角度)和语法上(形状兼容性和......)做什么!下面的实现将使用我们之前保存的模型中的相应变量替换名称中带有单词pooling 的任何参数。
finetuned_model = BertLMHeadModel.from_pretrained('bert-base-cased')
torch.save(finetuned_model.state_dict(), "finetuned_model.pth")
finetuned_model_state_dict = torch.load("finetuned_model.pth")
new_model = BertLMHeadModel.from_pretrained('bert-base-cased')
new_model_state_dict = new_model.state_dict()
for key, value in new_model_state_dict.items():
if key.find('pooling')!=-1:
new_model_state_dict.update({key: value})
【讨论】: