【发布时间】:2021-07-29 02:19:31
【问题描述】:
我正在使用 sklearn 估算器,它继承自 sklearn.base.BaseEstimator 并具有相当标准的界面。我想要做的一个例子是覆盖 .fit() 和 .predict() 方法来回归对数转换的目标,如下所示:
Estimator = sklearn.some_regression_estimator
class LogFit(Estimator):
"""subclass the sklearn regression estimator to fit and predict using
log-transformed target variable
"""
def __init__(self, **kwargs):
super().__init__(kwargs)
def fit(X, y=None, **kwargs):
super().fit(X, np.log(y), **kwargs)
return self
def predict(X):
return np.exp(super().predict(X))
我不一定事先知道将使用哪个估算器,只知道它会根据 sklearn 估算器约定运行。我也不想为每个可能的估计器重新编写上述子类,并且多重继承似乎不正确,因为LogFit 的每个实例都仅从单个父级继承。
我知道我可以编写一个包装类(然后使用覆盖的 fit() 和 predict() 方法子类 it),例如:
class EstimatorWrapper():
"""Wrapper class that has an estimator as a property"""
def __init__(self, estimator_instance):
self.estimator = estimator_instance
def fit(self, X, y=None, **kwargs):
self.fit(X, y, **kwargs)
return self
...
但在这一点上,我现在必须负责确保 EstimatorWrapper 类的行为就像基本估计器类一样,这样我就可以在不知道 sklearn 的其余机器知道的情况下使用 fit() 和 predict() 的 LogFit 版本区别。再说一次,如果我不知道每个可能的 estimator_instance 中存在哪些特定的方法/属性,似乎我必须以某种方式破解 EstimatorWrapper() 来动态定义其属性,而我真正想做的只是调整fit() 和 predict() 函数的行为。
我是否缺少一种编写子类的简单方法,该子类在实例化之前不知道其父类,还是根本不允许这样做?我找不到任何关于如何做前者的例子
【问题讨论】:
标签: python class inheritance scikit-learn subclass