【问题标题】:Monkey-Patching Magic Methods on scikit-learn Class Instancescikit-learn 类实例上的猴子修补魔术方法
【发布时间】:2015-11-10 15:21:23
【问题描述】:

我正在尝试构建一个名为 SafeModel 的工厂类,其 generate 方法接受一个 scikit-learn 类的实例,更改它的一些属性,并返回相同的实例.具体来说,对于这个例子,我想访问返回模型的 coef_ 属性,在案例 1)如果基础 scikit-learn 类包含 coef_,则返回该类'coef_,在案例 1 中2) 如果底层 scikit-learn 类包含feature_importances_,则返回该类'feature_importances

我已经成功地为 Python 类的实例修补了属性。我对 Python 类实例的猴子修补魔术方法的成功率较低。我的情况需要注意的是:属性coef_feature_importances 从未在 scikit-learn 类实例化时定义;相反,它们仅在对各自的类调用 fit 方法后定义。因此,我无法覆盖属性定义本身。

from types import MethodType


class SafeModel:

    FALLBACK_ATTRIBUTES = {
        'coef_': ['feature_importances_'],
    }

    @classmethod
    def generate(cls, model):
        safe_model = cls._secure_attributes(model)
        return safe_model

    @classmethod
    def _secure_attributes(cls, model):
        def __secure_getattr__(self, name):
            for fallback_attribute in cls.FALLBACK_ATTRIBUTES[name]:
                try:
                    return getattr(self, fallback_attribute)
                except:
                    continue
        model.__getattr__ = MethodType(__secure_getattr__, model)
        return model


    from sklearn.ensemble import RandomForestClassifier

    model = SafeModel.generate(RandomForestClassifier())
    model.coef_  # AttributeError: 'RandomForestClassifier' object has no attribute 'coef_'

【问题讨论】:

    标签: python scikit-learn monkeypatching magic-methods


    【解决方案1】:

    我无法查明您的代码有什么问题。不过,我找到了一个可能适用于您的用例的解决方法。
    我使用了不同的策略,因为我只是使用 SafeModel.__getattr__ 作为模型的 getattr 方法的包装器,而不是猴子修补。

    from sklearn.utils.validation import NotFittedError
    from sklearn.ensemble import RandomForestClassifier
    
    class SafeModel(object):
    
        def __init__(self, model):
            self.FALLBACK_ATTRIBUTES = {
            'coef_': ['feature_importances_'],
        }
            self.model = model
    
        def __getattr__(self, name):
            try:
                return getattr(self.model, name)
            except AttributeError:
                pass
            for fallback_attribute in self.FALLBACK_ATTRIBUTES[name]:
                try:
                    return getattr(self.model, fallback_attribute)
                except NotFittedError as e:
                    # NotFittedError inherits AttributeError.
                    raise e
                except AttributeError:
                    continue
            else:
                raise AttributeError(
                    "{} object has no attribute {}.".format(
                        self.__class__.__name__, name) + 
                    " Could not retrieve any fallback attribute.")                    
    
    
    model = SafeModel(RandomForestClassifier())
    model.coef_   
    

    输出:

    NotFittedError: Estimator not fitted, call `fit` before `feature_importances_`.
    

    请注意,这是正常行为,正如您所提到的,在您适应随机森林之前,您无法访问 feature_importances_

    诚然,这里的异常捕获相当脆弱(您需要添加一堆可能会引发的异常),但是如果您在尝试访问应该是的属性时不关心引发正确的异常很好。

    让我知道这是否适合您。如果您发现您发布的代码发生了什么,我也会对解释感兴趣!

    【讨论】:

    • 不幸的是,我希望将SafeModel 用作工厂,例如model = SafeModel.generate(RandomForestClassifier()),它返回一个修改后的sk-learn 对象实例。在上面,我们留下了一个 SafeModel 的实例。
    • 你不能把它用作装饰器吗?那么您的模型不是SafeModel 实例,而是用__getattr__ 装饰
    • @cavaunpeu 好的。仅供参考,我试图将打印语句放入您的“secure_getattr”中,但它们从未显示。我什至无法在 getattr 上应用一个简单的装饰器。 Challensois:如果你能做到,我很乐意看到解决方案。
    • @user3914041 我尝试了同样的方法,但无济于事。我认为 Python 调用魔术方法的方式与调用非魔术方法(通过访问相关实例的 dict 属性)完全不同。
    • @cavaunpeu 也许,我能够使用您的代码覆盖 fit。虽然无法使魔术方法起作用。
    猜你喜欢
    • 2015-03-23
    • 1970-01-01
    • 2011-03-20
    • 2021-03-10
    • 2016-10-30
    • 2010-10-01
    • 2012-08-19
    • 2022-11-24
    • 2016-11-27
    相关资源
    最近更新 更多