【发布时间】:2017-10-26 12:28:23
【问题描述】:
我写了一个 scikit-learn 估计器。它有一个参数和一个由fit 设置的model_ 属性。
class MyEstimator(BaseEstimator, TransformerMixin):
def __init__(self, param="default"):
self.param = param
self.model_ = None
def fit(self, x, y):
# Sets the value of self.model_
我希望能够腌制MyEstimator,但我创建的model_ 对象不能用pickle 序列化,因为它是keras 模型。按照博文“Pickling Keras Models”的示例,我将以下酸洗处理程序方法添加到我的类中。
class MyEstimator(BaseEstimator, TransformerMixin):
def __getstate__(self):
state = super().__getstate__().copy()
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
keras.models.save_model(self.model_, fd.name, overwrite=True)
state["model_"] = fd.read()
return state
def __setstate__(self, state):
super().__setstate__(state)
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
fd.write(state["model_"])
fd.flush()
self.__dict__["model_"] = keras.models.load_model(fd.name)
这将不可腌制的model_ 成员替换为由 keras 的序列化程序生成的可以腌制的表示。使用此自定义,我可以调用fit,进行序列化和反序列化,然后取回我的原始模型。一切正常。
e = MyEstimator()
e.fit(x, y)
with open("myfile.pk", mode="wb") as f:
pickle.dump(e, f)
with open("myfile.pk", mode="rb") as f:
pickle.load(f) # Returns a copy of e
但是,当我尝试将MyEstimator 放入pipeline 并腌制GridSearchCV 的结果时,序列化不起作用。
s = GridSearchCV(Pipeline([
# ...
("estimator", MyEstimator())
# ...
]))
s.fit(x, y)
with open("myfile.pk", mode="wb") as f:
pickle.dump(s, f)
在pickle.dump 调用期间,我希望看到MyEstimator.__getstate__ 被一个合适的self.model_ 对象调用。 (当我在网格搜索之外自行序列化模型时会发生这种情况。)而self.model_ 是None,所以我无法序列化网格搜索生成的best_estimator_。
看起来网格搜索序列化正在实例化一个新的MyEstimator 对象,而不是使用管道中的那个。这对我来说似乎是错误的。我查看了 scikit-learn 代码,但看不到这是在哪里发生的。
这是 scikit-learn 中的错误,还是我做错了什么?
(注意:keras 确实有一个 wrapper layer 可以将一些 keras 模型转换为 scikit-learn 估计器,但是由于其他原因我不能在这里使用它,而且我不确定它不会有相同的问题。)
【问题讨论】:
-
大多数 scikit 中的模型评估工具,如
cross_val_score、GridSearchCV等在拟合之前克隆给定的估计器。在 GridSearchCV 中,您可以看到它克隆的 source code here。 -
指出特定的源代码行会有所帮助。我在调试器中逐步完成这个过程并且迷路了。我不明白为什么在搜索完成后会有任何没有
fit调用的克隆。
标签: serialization scikit-learn keras pickle