【问题标题】:What type is a sklearn model?sklearn 模型是什么类型的?
【发布时间】:2021-08-24 00:28:11
【问题描述】:

我正在编写一些代码来根据一些数据评估不同的 sklearn 模型。我正在使用类型提示,既是为了我自己的教育,也是为了帮助其他最终必须阅读我的代码的人。

我的问题是如何指定 sklearn 预测器的类型(例如LinearRegression())?

例如:

def model_tester(model : Predictor,
                 parameter: int
                 ) -> np.ndarray:
     """An example function with type hints."""

     # do stuff to model 

     return values

我看到typing library 可以制作新类型或者我可以使用TypeVar 来做:

Predictor = TypeVar('Predictor') 

但如果 sklearn 模型已经有常规类型,我不想使用它。

检查 LinearRegression() 的类型产生:

 sklearn.linear_model.base.LinearRegression

这显然是有用的,但前提是我对线性回归模型感兴趣。

【问题讨论】:

    标签: python scikit-learn


    【解决方案1】:

    从 Python 3.8 开始(或更早版本使用typing-extensions),您可以使用typing.Protocol。使用协议,您可以使用名为 structural subtyping 的概念来准确定义类型的预期结构:

    from typing import Protocol
    # from typing_extensions import Protocol  # for Python <3.8
    
    class ScikitModel(Protocol):
        def fit(self, X, y, sample_weight=None): ...
        def predict(self, X): ...
        def score(self, X, y, sample_weight=None): ...
        def set_params(self, **params): ...
    

    然后您可以将其用作类型提示:

    def do_stuff(model: ScikitModel) -> Any:
        model.fit(train_data, train_labels)  # this type checks 
        score = model.score(test_data, test_labels)  # this type checks
        ...
    

    【讨论】:

      【解决方案2】:

      我认为所有模型都继承自的最通用的类​​是sklearn.base.BaseEstimator

      如果您想更具体,可以使用sklearn.base.ClassifierMixinsklearn.base.RegressorMixin

      所以我会这样做:

      from sklearn.base import RegressorMixin
      
      
      def model_tester(model: RegressorMixin, parameter: int) -> np.ndarray:
           """An example function with type hints."""
      
           # do stuff to model 
      
           return values
      

      我不是类型检查方面的专家,如果这不正确,请纠正我。

      【讨论】:

      • 感谢您的回答。我尝试了 BaseEstimator 和 ClassifierMixin。但是当我调用 self.estimator.fit 时,我的 IDE(Pycharm)抱怨它找不到“fit”属性。这是正确的。这些类没有实现 fit。它是为每个估计器单独实现的(例如 LogisticRegression)。有谁知道应该采用 scikit-learn 估计器的参数的正确类型提示是什么?
      【解决方案3】:

      一个好的解决方法是创建您自己的自定义类型提示类(使用联合),其中包括您常用的所有模型。它需要更多的努力,但可以让您更加具体并与 PyCharm 一起使用。

      ModelRegressor = Union[LinearRegression, DecisionTreeRegressor, RandomForestRegressor, SVR]
      
      def foo(model: ModelRegressor):
          do_something
      

      【讨论】:

        【解决方案4】:

        您可以在任何地方将类型设置为sklearn.pipeline.Pipeline。这可能是一种无需创建额外实体的解决方案。

        【讨论】:

          猜你喜欢
          • 2019-05-11
          • 2018-03-17
          • 1970-01-01
          • 1970-01-01
          • 2018-12-28
          • 1970-01-01
          • 2018-07-14
          • 1970-01-01
          • 2019-08-10
          相关资源
          最近更新 更多