【发布时间】:2021-08-12 23:55:15
【问题描述】:
我想使用 pytorch-lightning 将模型权重保存到 mlflow 跟踪。 pytorch-lightning 支持logging。 但是,似乎不支持将模型权重保存为 mlflow 上的工件。
一开始是打算重写ModelCheckpoint类来做的,但是因为复杂的mixin操作,我发现很难。
有人知道简单的方法吗?
【问题讨论】:
标签: pytorch
我想使用 pytorch-lightning 将模型权重保存到 mlflow 跟踪。 pytorch-lightning 支持logging。 但是,似乎不支持将模型权重保存为 mlflow 上的工件。
一开始是打算重写ModelCheckpoint类来做的,但是因为复杂的mixin操作,我发现很难。
有人知道简单的方法吗?
【问题讨论】:
标签: pytorch
正如@xela 所说,您可以使用mlflow 记录器的experiment 对象来记录工件。
如果您想在训练期间频繁记录模型权重,可以扩展 ModelCheckpoint:
class MLFlowModelCheckpoint(ModelCheckpoint):
def __init__(self, mlflow_logger, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mlflow_logger = mlflow_logger
@rank_zero_only
def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
run_id = self.mlflow_logger.run_id
self.mlflow_logger.experiment.log_artifact(run_id, self.best_model_path)
然后在你的训练代码中使用
mlflow_logger = MLFlowLogger()
checkpoint_callback = MLFlowModelCheckpoint(mlflow_logger)
trainer = pl.Trainer(checkpoint_callback=checkpoint_callback, logger=mlflow_logger)
【讨论】:
on_validation_end - 也在基类中 - 仅在排名为零的过程中被调用。我建议添加另一个用@rank_zero_only 装饰的方法(upload_best_model)并从on_validation_end 调用它(始终执行)。
@stecklin 的替代方法是使用记录器的after_save_checkpoint 方法。
你可以像这样扩展MLFlowLogger:
class MLFlowLoggerCheckpointer(pl.MLFlowLogger):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def after_save_checkpoint(self, model_checkpoint: pl.ModelCheckpoint) -> None:
"""
Called after model checkpoint callback saves a new checkpoint.
"""
self.experiment.log_artifact(
self.run_id, model_checkpoint.best_model_path
)
【讨论】: