【问题标题】:PyTorch Lightning: includes some Tensor objects in checkpoint filePyTorch Lightning:在检查点文件中包含一些张量对象
【发布时间】:2021-07-05 06:06:09
【问题描述】:

由于 Pytorch Lightning 为模型检查点提供自动保存功能,我使用它来保存 top-k 最佳模型。特别是在 Trainer 设置中,

    checkpoint_callback = ModelCheckpoint(
                            monitor='val_acc',
                            dirpath='checkpoints/',
                            filename='{epoch:02d}-{val_acc:.2f}',
                            save_top_k=5,
                            mode='max',
                            )

这很好,但它没有保存模型对象的某些属性。我的模型在每个训练周期结束时都会存储一些张量,这样

class SampleNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(100, 1)
        self.loss = torch.nn.CrossEntropy()

        self.some_data = None # Initialize as None

    def training_step(self, batch):
        x, t = batch
        out = self.layer(x)
        loss = self.loss(out, t)
        results = {'loss': loss}

        return results

    def training_epoch_end(self, outputs):
       
        self.some_data = some_tensor_object

这是一个简化的示例,但我希望上面checkpoint_callback 制作的检查点文件记住属性self.some_data,但是当我从检查点加载模型时,它总是重置为None。我确认在训练过程中更新成功。

我尝试不在init 中将其初始化为None,但加载模型时该属性会消失。

我想避免将属性保存为不同的pt 文件,因为它与模型配置相关联,因此我需要稍后手动将文件与相应的检查点文件匹配。

是否可以在检查点文件中包含这样的张量属性?

【问题讨论】:

    标签: python pytorch pytorch-lightning


    【解决方案1】:

    只需将模型类挂钩 on_save_checkpoint()on_load_checkpoint() 用于您想要与默认属性一起保存的各种对象。

    def on_save_checkpoint(self, checkpoint) -> None:
        "Objects to include in checkpoint file"
        checkpoint["some_data"] = self.some_data
    
    def on_load_checkpoint(self, checkpoint) -> None:
        "Objects to retrieve from checkpoint file"
        self.some_data= checkpoint["some_data"]
    

    See module docs

    【讨论】:

      【解决方案2】:

      似乎不能直接使用,因为提取参数最有可能使用nn.Module.state_dict()。 此方法仅提取实际被视为参数的张量的值。因此,在这种情况下,一种解决方法是将您的数据保存为参数(请参阅docs):

      self.some_data = torch.nn.parameter.Parameter(your_data)
      

      【讨论】:

        猜你喜欢
        • 2021-05-14
        • 2021-10-05
        • 2021-12-21
        • 2021-01-15
        • 2021-02-11
        • 1970-01-01
        • 2020-10-29
        • 2020-07-17
        • 2021-05-08
        相关资源
        最近更新 更多