【问题标题】:Am I using Decorators correctly?我正确使用装饰器吗?
【发布时间】:2021-05-30 05:58:36
【问题描述】:

我不确定如何正确使用装饰器;我参考了Real PythonTry-Except for Multiple Methods。我正在编写一个线性回归类,我意识到你需要先调用fit,然后才能进行预测,或者我的类拥有的其他方法。但是当self._fitted 标志为False 时,定义每个引发错误的方法很麻烦。所以我求助于装饰器,我不确定我是否正确使用,因为它的行为确实符合我的要求,但是它忽略了任何其他形式的错误,如 ValueError 等。在这里寻求建议。

import functools
from sklearn.exceptions import NotFittedError


def NotFitted(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            raise NotFittedError

    return wrapper

class LinearRegression:
    def __init__(self, fit_intercept: bool = True):
        self.coef_ = None
        self.intercept_ = None
        self.fit_intercept = fit_intercept
        # a flag to turn to true once we called fit on the data
        self._fitted = False

def check_shape(self, X: np.array, y: np.array):
    # if X is 1D array, then it is simple linear regression, reshape to 2D
    # [1,2,3] -> [[1],[2],[3]] to fit the data
    if X is not None and len(X.shape) == 1:
        X = X.reshape(-1, 1)
    # self._features = X
    # self.intercept_ = y
    return X, y

def fit(self, X: np.array = None, y: np.array = None):
    X, y = self.check_shape(X, y)
    n_samples, n_features = X.shape[0], X.shape[1]
    if self.fit_intercept:
   
        X = np.c_[np.ones(n_samples), X]
    XtX = np.dot(X.T, X)
    XtX_inv = np.linalg.inv(XtX)
    XtX_inv_Xt = np.dot(XtX_inv, X.T)
    _optimal_betas = np.dot(XtX_inv_Xt, y)

    # set attributes from None to the optimal ones
    self.coef_ = _optimal_betas[1:]
    self.intercept_ = _optimal_betas[0]
    self._fitted = True

    return self

@NotFitted
def predict(self, X: np.array):
    """
    after calling .fit, you can continue to .predict to get model prediction
    """
    # if self._fitted is False:
    #     raise NotFittedError
    if self.fit_intercept:
        y_hat = self.intercept_ + np.dot(X, self.coef_)
    else:
        y_hat = self.intercept_
    return y_hat

【问题讨论】:

  • 如果您对这种方法的问题是包装的func 中的错误没有传播,您可以更改装饰器中的错误处理以重新抛出错误,而不是始终使用NotFittedError。但我很困惑:除了捕捉这些错误并将它们屏蔽为NotFittedError 之外,您的装饰器是否还有任何作用?我不认为_fitted 有人读过吗?
  • 是的,这可能是我想问的,我需要在装饰器中调用_fitted吗?

标签: python oop machine-learning scikit-learn decorator


【解决方案1】:

让我快速重复你想要做的事情,以确保我没有误解。 你想要一个装饰器@NotFitted,这样你用它注释的每个函数都会首先检查self._fitted是否是True,如果它是False,则以NotFittedError失败而不是执行函数。

通过查看this question,您可以了解如何将其他参数传递给装饰器。
我不习惯使用装饰器,所以我不得不快速测试一下,看看你的代码中发生了什么——为什么def wrapper 不需要参数self

>>> def deco1(func):
...   def wrapper(*args, **kwargs):
...     print("Args are {}".format(args))
...   return wrapper

>>> class Foo(object):
...   @deco1
...   def meth(self, a):
...     print("a: "+a)

>>> f = Foo()
>>> f.meth("hello")
Args are (<__main__.Foo object at 0x7f37676a4128>, 'hello')

正如您在此处看到的,wrapper 打印的第一个参数实际上是self*args 只是将所有非关键字参数收集到一个元组中,包括self,这是这里的第一个参数。如果我们愿意,我们可以通过 def wrapper(self, *args, **kwargs) 来更明确(请参阅链接的问题)。

我需要在装饰器中调用_fitted吗?

是的,因为self._fitted 是您跟踪它是否已安装的方式。你可以通过*args 的第一个元素通过args[0]._fitted 访问它。但我更喜欢明确地传递自我。无论哪种方式,您都可以在wrapper 内部检查self._fitted 是否为True,如果不是则失败。所以我定义了这个例子:

#!/bin/env/python3
# Declaring my own NotFittedError, because I don't want to
# from sklearn.exceptions import NotFittedError
# just for this small example.

class NotFittedError (Exception):
    pass

def NotFitted ( foo ):
    def wrapper ( self, *args, **kwargs ):
        if not self._fitted:
            raise NotFittedError()
        else:
            foo ( self, *args, **kwargs )

    return wrapper

class Foo() :
    # Set self._fitted to false just to be explicit.
    # The initial value should be False anyway.
    def __init__(self):
        self._fitted = False

    def fit(self):
        self._fitted = True

    @NotFitted
    def predict(self, X):
        # code here that assumes fit was already called
        print ( "Successfully Predicted!" )

现在我们可以使用它了。在下面的 sn-p 中,我将它导入为 tmp,因为我将它放在一个名为 tmp.py 的文件中。您不必这样做,因为您将所有内容都放在同一个文件中。

>>> import tmp
>>> f = tmp.Foo()
>>> f.predict("a")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/generic/Downloads/tmp.py", line 12, in wrapper
    raise NotFittedError()
tmp.NotFittedError
>>> f.fit()
>>> f.predict("a")
Successfully Predicted!

一些进一步的cmets:

【讨论】:

    猜你喜欢
    • 2011-12-15
    • 1970-01-01
    • 2014-06-08
    • 1970-01-01
    • 2020-03-31
    • 2023-03-16
    • 2011-09-08
    • 1970-01-01
    相关资源
    最近更新 更多