[Question title]: Customizing Random Forest classifier sklearn
[Posted]: 2019-11-16 07:47:24
[Question description]:

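For personal purposes, I am trying to modify the Random Forest Classifier class from sklearn to do what I need. Basically, I want the trees in my random forest to use some predefined subsamples of features and cases, so I am modifying the default class. I am trying to inherit all the methods and structure of the original sklearn class so that the fit method of my customized random forest class can take sklearn's original parameters.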

For example, I would like my customized class to accept the same arguments as the original:

clf = RandomForestClassifier(n_estimators=10, max_depth=2, random_state=None, max_features=None...)


clf = Customized_RF(n_estimators=10, max_depth=2, random_state=None, max_features=None...)

However, I am having trouble doing this. Specifically, it seems to be related to the super().__init__ definition, and I get the following error: TypeError: object.__init__() takes no arguments

I am following a GitHub repository as a guide (Rf class).

Am I doing something wrong or missing an obvious step?

This is my approach so far:

import numpy as np
from sklearn.tree import DecisionTreeClassifier

class Customized_RF:
    def __init__(self, n_estimators=10, criterion='gini', max_depth=None, random_state=None):

        super().__init__(base_estimator=DecisionTreeClassifier(),
                         n_estimators=n_estimators,
                         estimator_params=("criterion", "max_depth")) # Here's where the error happens

        self.n_estimators = n_estimators

        if random_state is None:
            self.random_state = np.random.RandomState()
        else:
            self.random_state = np.random.RandomState(random_state)

        self.criterion = criterion
        self.max_depth = max_depth

    def fit(self, X, y, max_features=None, cutoff=None, bootstrap_frac=0.8):
        """
        max_features: number of features that each estimator will use,
                      including the fixed features.

        bootstrap_frac: the size of bootstrap sample that each estimator will use.

        cutoff: feature index around which the feature subsampling is split.
                Each tree draws a random subset of features from before and
                after the cutoff. This assumes the feature matrix has not been
                reordered or otherwise altered (e.g. made sparse).

        """
        self.estimators = []
        self.feats_used = []  # feature indexes chosen for each estimator
        self.n_classes  = np.unique(y).shape[0]

        if max_features is None:
            max_features = X.shape[1]  # if max_features is None select all features for every estimator like original

        if cutoff is None:
            cutoff = int(X.shape[1] / 2)  # pick the central index number of the x vector

        print('Cutoff x vector: {}'.format(cutoff))

        n_samples = X.shape[0]
        n_bs = int(bootstrap_frac*n_samples)  # fraction of samples to be used for every estimator (DT)

        for i in range(self.n_estimators):
                                    replace=False)

            feats_left = self.random_state.choice(cutoff + 1, int(max_features / 2), replace=False)  # inclusive cutoff
            feats_right = self.random_state.choice(range(cutoff + 1, X.shape[1]), int(max_features/2), replace=False)
            # exclusive cutoff

            feats = np.concatenate((feats_left, feats_right)).tolist()

            self.feats_used.append(feats)

            print('Chosen feature indexes for estimator number {0}: {1}'.format(i, feats))

            bs_sample = self.random_state.choice(n_samples, 
                                                 size=n_bs,
                                                 replace=True)

            dtc = DecisionTreeClassifier(random_state=self.random_state)
            dtc.fit(X[bs_sample][:, feats], y[bs_sample])
            self.estimators.append(dtc)

    def predict_proba(self, X):
        out = np.zeros((X.shape[0], self.n_classes))
        for i in range(self.n_estimators):
            out += self.estimators[i].predict_proba(X[:, self.feats_used[i]])
        return out / self.n_estimators

    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)

    def score(self, X, y):
        return (self.predict(X) == y).mean()

[Question discussion]:

    Tags: python-3.x oop scikit-learn random-forest


    [Solution 1]:

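    If you want to derive your own class from another class, the class definition needs to reference the base class, e.g. class MyClass(BaseClass). super() then refers to that base class.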

    In your case the base class is missing, so Python assumes the generic class object, whose __init__ accepts no extra arguments.
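A minimal sketch of the point above, using plain Python only. The class names NoBase, Base, and WithBase are hypothetical and chosen for illustration; they are not part of sklearn or of the question's code.

```python
class NoBase:
    """No base class is listed, so super() resolves to object."""

    def __init__(self, n_estimators=10):
        # object.__init__ accepts no extra arguments, so this raises TypeError.
        super().__init__(n_estimators=n_estimators)


class Base:
    """Hypothetical parent whose __init__ actually accepts n_estimators."""

    def __init__(self, n_estimators=10):
        self.n_estimators = n_estimators


class WithBase(Base):
    """Listing Base in the class statement makes super() refer to Base."""

    def __init__(self, n_estimators=10, max_depth=None):
        super().__init__(n_estimators=n_estimators)
        self.max_depth = max_depth


# Constructing NoBase fails in the same way as the class in the question.
try:
    NoBase(n_estimators=5)
    no_base_failed = False
except TypeError:
    no_base_failed = True

# With an explicit base class, the same call pattern works.
model = WithBase(n_estimators=5, max_depth=2)
```

The only difference between the failing and the working class is the base class named in the class statement.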

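    It is not clear from your question whether the base class you want is DecisionTreeClassifier or RandomForestClassifier. In either case, you will need to change the arguments passed in __init__ so that they match what that base class accepts.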
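As a rough sketch, assuming RandomForestClassifier is the intended base class, the constructor could forward only keyword arguments that RandomForestClassifier itself accepts. The class name CustomizedRF and the toy data are made up for illustration, and the custom feature subsampling from the question is not implemented here.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier


class CustomizedRF(RandomForestClassifier):
    """Sketch: a subclass that reuses sklearn's fit/predict unchanged."""

    def __init__(self, n_estimators=10, criterion="gini", max_depth=None,
                 max_features=None, random_state=None):
        # Forward only keyword arguments that RandomForestClassifier accepts.
        super().__init__(n_estimators=n_estimators,
                         criterion=criterion,
                         max_depth=max_depth,
                         max_features=max_features,
                         random_state=random_state)


# Toy data (made up for illustration): 40 samples, 6 features, 2 classes.
rng = np.random.RandomState(0)
X = rng.rand(40, 6)
y = (X[:, 0] > 0.5).astype(int)

clf = CustomizedRF(n_estimators=10, max_depth=2, random_state=0,
                   max_features=None)
clf.fit(X, y)          # inherited from RandomForestClassifier
pred = clf.predict(X)  # inherited as well
```

Any custom sampling logic would still have to be added on top of this, for example by overriding fit, which this sketch deliberately leaves out.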

    Minor: check the replace=False) line, it is not valid syntax.

    [Comments]:

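    • Really helpful, thanks. I guess my base class should be RandomForestClassifier. The thing is, I only want to modify the max_features subsampling, but it is not clear where in the source code this modification should be added. By the way, I can't see the invalid syntax you mention on the replace=False) line.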
    • @KennethRivadeneiraGuadamud "If you want to say 'thank you', vote on or accept that person's answer" (from What should I do when someone answers my question?)