【问题标题】:(Python - sklearn) How to pass parameters to the customize ModelTransformer class by gridsearchcv(Python - sklearn)如何通过gridsearchcv将参数传递给自定义的ModelTransformer类
【发布时间】:2015-03-04 20:09:28
【问题描述】:

下面是我的管道,我似乎无法通过使用 ModelTransformer 类将参数传递给我的模型,我从链接 (http://zacstewart.com/2014/08/05/pipelines-of-featureunions-of-pipelines.html) 获取它

错误消息对我来说很有意义,但我不知道如何解决这个问题。知道如何解决这个问题吗?谢谢。

# define a pipeline
pipeline = Pipeline([
('vect', DictVectorizer(sparse=False)),
('scale', preprocessing.MinMaxScaler()),
('ess', FeatureUnion(n_jobs=-1, 
                     transformer_list=[
     ('rfc', ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1,  n_estimators=100))),
     ('svc', ModelTransformer(SVC(random_state=1))),],
                     transformer_weights=None)),
('es', EnsembleClassifier1()),
])

# define the parameters for the pipeline
parameters = {
'ess__rfc__n_estimators': (100, 200),
}

# ModelTransformer class. It takes it from the link
(http://zacstewart.com/2014/08/05/pipelines-of-featureunions-of-pipelines.html)
class ModelTransformer(TransformerMixin):
    def __init__(self, model):
        self.model = model
    def fit(self, *args, **kwargs):
        self.model.fit(*args, **kwargs)
        return self
    def transform(self, X, **transform_params):
        return DataFrame(self.model.predict(X))

grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1, refit=True)

错误信息: ValueError:估计器 ModelTransformer 的参数 n_estimators 无效。

【问题讨论】:

  • 感谢您的提问——我也有同样的问题。让我再问你一件事。你知道为什么 self.model.fit(*args, **kwargs) 有效吗?我的意思是你通常不会在调用 fit 方法时传递像 n_estimators 这样的超参数,但是在定义类实例时,例如 rfc=RandomForestClassifier(n_estimators=100), rfc.fit(X,y)
  • @drake,当你创建一个 ModelTransformer 实例时,你需要传入一个带有参数的模型。例如,ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1, n_estimators=100)))。而这里 self.model.fit(*args, **kwargs) 主要是指 self.model.fit(X, y)。
  • 谢谢,@nkhuyu。我知道它是这样工作的。我在问为什么。由于 self.model = 模型,self.model=RandomForestClassifier(n_jobs=-1, random_state=1, n_estimators=100)。我知道 *args 正在解包 (X, y),但我不明白为什么 self.model 已经知道超参数时在 fit 方法中需要 **kwargs。

标签: python-2.7 machine-learning parameter-passing scikit-learn cross-validation


【解决方案1】:

GridSearchCV 对嵌套对象有一个特殊的命名约定。在您的情况下ess__rfc__n_estimators 代表ess.rfc.n_estimators,并且根据pipeline 的定义,它指向属性n_estimators

ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1,  n_estimators=100)))

显然,ModelTransformer 实例没有这样的属性。

修复很简单:为了访问ModelTransformer 的底层对象,需要使用model 字段。所以,网格参数变成了

parameters = {
  'ess__rfc__model__n_estimators': (100, 200),
}

P.S.这不是您的代码的唯一问题。为了在 GridSearchCV 中使用多个作业,您需要使您正在使用的所有对象都可复制。这是通过实现方法get_paramsset_params来实现的,你可以从BaseEstimator mixin中借用它们。

【讨论】:

  • 你能稍微扩展一下这个 PS 吗?我想我有同样的问题,当我尝试将 gridsearchcv 与管道功能联合使用时,我收到错误 AttributeError: 'SelectColumns' object has no attribute 'get_params' 其中 SelectColumns 是我为管道编写的类。
  • @B_Miner,您应该从提供上述set_paramsget_paramsBaseEstimator 继承您的SelectColumns 类。或者,您可以实现自己的,但大多数时候您不想这样做。
  • 我正在寻找 BaseEstimatorMixin。我从 BaseEstimator 继承而来,它就像一个魅力,谢谢!
  • @ArtemSobolev 我正在做同样的事情。当我尝试将 cross_val_predict 或 gridsearch CV 与相同的管道一起使用时,出现错误“无法深度复制此模式对象”。您能否展示一下您是如何使用功能联合的?
猜你喜欢
  • 1970-01-01
  • 2019-08-04
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2017-07-06
  • 2013-06-03
  • 2021-08-18
相关资源
最近更新 更多