【发布时间】:2021-07-05 06:06:09
【问题描述】:
由于 Pytorch Lightning 为模型检查点提供自动保存功能,我使用它来保存 top-k 最佳模型。特别是在 Trainer 设置中,
checkpoint_callback = ModelCheckpoint(
monitor='val_acc',
dirpath='checkpoints/',
filename='{epoch:02d}-{val_acc:.2f}',
save_top_k=5,
mode='max',
)
这很好,但它没有保存模型对象的某些属性。我的模型在每个训练周期结束时都会存储一些张量,这样
class SampleNet(pl.LightningModule):
def __init__(self):
super().__init__()
self.save_hyperparameters()
self.layer = torch.nn.Linear(100, 1)
self.loss = torch.nn.CrossEntropy()
self.some_data = None # Initialize as None
def training_step(self, batch):
x, t = batch
out = self.layer(x)
loss = self.loss(out, t)
results = {'loss': loss}
return results
def training_epoch_end(self, outputs):
self.some_data = some_tensor_object
这是一个简化的示例,但我希望上面checkpoint_callback 制作的检查点文件记住属性self.some_data,但是当我从检查点加载模型时,它总是重置为None。我确认在训练过程中更新成功。
我尝试不在init 中将其初始化为None,但加载模型时该属性会消失。
我想避免将属性保存为不同的pt 文件,因为它与模型配置相关联,因此我需要稍后手动将文件与相应的检查点文件匹配。
是否可以在检查点文件中包含这样的张量属性?
【问题讨论】:
标签: python pytorch pytorch-lightning