【问题标题】:Get corresponding classes to predict_proba (GridSearchCV sklearn)获取对应的类到predict_proba (GridSearchCV sklearn)
【发布时间】:2020-02-14 01:39:59
【问题描述】:

我正在使用 GridSearchCV 和管道对一些文本文档进行分类。一个代码sn-p:

clf = Pipeline([('vect', TfidfVectorizer()), ('clf', SVC())])
parameters = {'vect__ngram_range' : [(1,2)], 'vect__min_df' : [2], 'vect__stop_words' : ['english'],
                  'vect__lowercase' : [True], 'vect__norm' : ['l2'], 'vect__analyzer' : ['word'], 'vect__binary' : [True], 
                  'clf__kernel' : ['rbf'], 'clf__C' : [100], 'clf__gamma' : [0.01], 'clf__probability' : [True]} 
grid_search = GridSearchCV(clf, parameters, n_jobs = -2, refit = True, cv = 10)
grid_search.fit(corpus, labels)

我的问题是,当使用grid_serach.predict_proba(new_doc) 然后想找出概率与grid_search.classes_ 对应的类时,我收到以下错误:

AttributeError: 'GridSearchCV' 对象没有属性 'classes_'

我错过了什么?我想如果管道中的最后一个“步骤”是一个分类器,那么 GridSearchCV 的返回也是一个分类器。因此,可以使用该分类器的属性,例如类_。

【问题讨论】:

    标签: python scikit-learn text-classification


    【解决方案1】:

    如上面的 cmets 中所述,grid_search.best_estimator_.classes_ 返回了错误消息,因为它返回了一个没有属性 .classes_ 的管道。但是,通过首先调用管道的步骤分类器,我能够使用 classes 属性。这是解决方案

    grid_search.best_estimator_.named_steps['clf'].classes_
    

    【讨论】:

      【解决方案2】:

      试试grid_search.best_estimator_.classes_

      GridSearchCV 的返回是一个 GridSearchCV 实例,它本身并不是一个真正的估计器。相反,它会为它尝试的每个参数组合实例化一个新的估计器(请参阅the docs)。

      你可能认为返回值是一个分类器,因为你可以在refit=True时使用predictpredict_proba等方法,但GridSearchCV.predict_proba实际上看起来像(来自源的剧透):

      def predict_proba(self, X):
          """Call predict_proba on the estimator with the best found parameters.
          Only available if ``refit=True`` and the underlying estimator supports
          ``predict_proba``.
          Parameters
          -----------
          X : indexable, length n_samples
              Must fulfill the input assumptions of the
              underlying estimator.
          """
          return self.best_estimator_.predict_proba(X)
      

      希望这会有所帮助。

      【讨论】:

      • “grid_search.best_estimator_.classes_”不起作用。我收到一条错误消息,说管道没有名为 classes_ 的属性。但是,我设法找到了解决方案(见答案)。
      • 好的。我以为会是这种情况,但事实证明它对我有用,举一个类似于你的例子。 grid_search.best_estimator_ 是一个 Pipeline 对象,但我仍然可以得到 grid_search.best_estimator_.classes_。我正在使用开发版本。或者,您可以使用 steps 属性访问管道的每个步骤:dict(grid_search.best_estimator_.steps)["clf"].classes_ 应该适合您。
      • 好吧,也许这就是区别。我之前找到的解决方案几乎相同,我直接使用named_steps,而不是在使用steps属性时创建dict(见答案)。感谢您的帮助!
      猜你喜欢
      • 2021-11-12
      • 2018-03-03
      • 2017-03-29
      • 2018-01-05
      • 2020-06-24
      • 2018-06-08
      • 2016-05-20
      • 2018-08-14
      相关资源
      最近更新 更多