【问题标题】:Pytorch save embeddings as part of encoder class or notPytorch 是否将嵌入保存为编码器类的一部分
【发布时间】:2018-09-05 02:34:22
【问题描述】:

所以我是第一次使用 pytorch。我正在尝试将权重保存到文件中。我正在使用具有 GRU 和嵌入组件的编码器类。我想确保当我保存编码器值时,我将获得嵌入值。最初,我的代码使用 state_dict() 将值复制到我自己的字典中,然后将其传递给 torch.save()。我应该寻找某种方法来保存这个嵌入组件还是它是更大编码器的一部分? Encoder 是 nn.Module 的子类。这是一个链接:

http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#sphx-glr-intermediate-seq2seq-translation-tutorial-py

def make_state(self, converted=False):
    if not converted:
        z = [
            {
                'epoch':0,
                'arch': None,
                'state_dict': self.model_1.state_dict(),
                'best_prec1': None,
                'optimizer': self.opt_1.state_dict(),
                'best_loss': self.best_loss
            },
            {
                'epoch':0,
                'arch':None,
                'state_dict':self.model_2.state_dict(),
                'best_prec1':None,
                'optimizer': self.opt_2.state_dict(),
                'best_loss': self.best_loss
            }
        ]
    else:
        z = [
            {
                'epoch': 0,
                'arch': None,
                'state_dict': self.model_1.state_dict(),
                'best_prec1': None,
                'optimizer': None , # self.opt_1.state_dict(),
                'best_loss': self.best_loss
            },
            {
                'epoch': 0,
                'arch': None,
                'state_dict': self.model_2.state_dict(),
                'best_prec1': None,
                'optimizer': None, # self.opt_2.state_dict(),
                'best_loss': self.best_loss
            }
        ]
    #print(z)
    return z
    pass

def save_checkpoint(self, state=None, is_best=True, num=0, converted=False):
    if state is None:
        state = self.make_state(converted=converted)
        if converted: print(converted, 'is converted.')
    basename = hparams['save_dir'] + hparams['base_filename']
    torch.save(state, basename + '.' + str(num)+ '.pth.tar')
    if is_best:
        os.system('cp '+ basename + '.' + str(num) + '.pth.tar' + ' '  +
                  basename + '.best.pth.tar')

https://discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/2610/3

这是另一个链接

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    不,您不需要显式保存嵌入值。保存模型的 state_dict 将保存与该模型相关的所有变量,包括嵌入权重。
    您可以通过循环来查找状态字典包含的内容 -

    for var_name in model.state_dict():
        print(var_name)
    

    【讨论】:

      猜你喜欢
      • 2021-11-08
      • 2021-11-17
      • 2016-03-30
      • 2020-04-19
      • 1970-01-01
      • 2020-02-07
      • 1970-01-01
      • 2021-09-25
      • 1970-01-01
      相关资源
      最近更新 更多