【问题标题】:Override method for a collection of classes implementing an interface实现接口的类集合的重写方法
【发布时间】:2015-11-27 08:43:50
【问题描述】:

我正在使用 scikit-learn 并正在构建管道。构建管道后,我将使用 GridSearchCV 来查找最佳模型。我正在处理文本数据,所以我正在尝试不同的词干分析器。我创建了一个名为 Preprocessor 的类,它接受一个词干分析器和向量化器类,然后尝试覆盖向量化器的方法 build_analyzer 以合并给定的词干分析器。但是,我看到 GridSearchCV 的 set_params 只是直接访问实例变量——即它不会像我一直在做的那样用新的分析器重新实例化矢量化器:

class Preprocessor(object):
    # hard code the stopwords for now
    stopwords = nltk.corpus.stopwords.words()

    def __init__(self, stemmer_cls, vectorizer_cls):
        self.stemmer = stemmer_cls()
        analyzer = self._build_analyzer(self.stemmer, vectorizer_cls)
        self.vectorizer = vectorizer_cls(stopwords=stopwords,
                                         analyzer=analyzer,
                                         decode_error='ignore')

    def _build_analyzer(self, stemmer, vectorizer_cls):
        # analyzer tokenizes and lowercases
        analyzer = super(vectorizer_cls, self).build_analyzer()
        return lambda doc: (stemmer.stem(w) for w in analyzer(doc))

    def fit(self, **kwargs):
        return self.vectorizer.fit(kwargs)

    def transform(self, **kwargs):
        return self.vectorizer.transform(kwargs)

    def fit_transform(self, **kwargs):
        return self.vectorizer.fit_transform(kwargs)

所以问题是:如何为传入的所有矢量化器类覆盖 build_analyzer?

【问题讨论】:

    标签: python machine-learning overriding scikit-learn porter-stemmer


    【解决方案1】:

    是的,GridSearchCV 直接设置实例字段,然后在字段发生变化的分类器上调用 fit。

    scikit-learn 中的每个分类器都是以这样的方式构建的,__init__ 仅设置参数字段,并且进一步工作所需的所有依赖对象(如在您的情况下调用 _build_analyzer)仅在 fit 方法中构建。您必须添加存储 vectorizer_cls 的附加字段,然后您必须在 fit 方法中从 vectorized_cls 和 stemmer_cls 对象构造依赖。

    类似:

    class Preprocessor(object):
        # hard code the stopwords for now
        stopwords = nltk.corpus.stopwords.words()
    
        def __init__(self, stemmer_cls, vectorizer_cls):
            self.stemmer_cls = stemmer_cls
            self.vectorizer_cls = vectorizer_cls
    
        def _build_analyzer(self, stemmer, vectorizer_cls):
            # analyzer tokenizes and lowercases
            analyzer = super(vectorizer_cls, self).build_analyzer()
            return lambda doc: (stemmer.stem(w) for w in analyzer(doc))
    
        def fit(self, **kwargs):
            analyzer = self._build_analyzer(self.stemmer_cls(), vectorizer_cls)
            self.vectorizer_cls = vectorizer_cls(stopwords=stopwords,
                                             analyzer=analyzer,
                                             decode_error='ignore')
    
            return self.vectorizer_cls.fit(kwargs)
    
        def transform(self, **kwargs):
            return self.vectorizer_cls.transform(kwargs)
    
        def fit_transform(self, **kwargs):
            return self.vectorizer_cls.fit_transform(kwargs)
    

    【讨论】:

      猜你喜欢
      • 2019-01-24
      • 1970-01-01
      • 2023-01-23
      • 1970-01-01
      • 2014-04-12
      • 2012-03-22
      • 2011-05-16
      • 2014-10-25
      • 1970-01-01
      相关资源
      最近更新 更多