【问题标题】:Scikit learn fit estimator with predefined number of classes具有预定义类数的 Scikit 学习拟合估计器
【发布时间】:2017-07-11 20:41:53
【问题描述】:

所以,我需要使用 scikit-learn 中的一些估计器,即 LogisticRegression 和 SVM,但是我有一个问题,我有一个非常不平衡的数据集,需要运行 Kfold 交叉验证。问题是有时我适合的折叠只能有一个可用的目标类。我想知道这些估计器是否有任何方法可以预定义类的数量,可能类似于向它们传递目标的单热编码表示,如果所有示例都来自一个类并不重要,形状目标矩阵已经定义了类的数量。

有没有办法用 scikit-learn 做到这一点?也许与另一个图书馆?我知道这两种算法使用 liblinear,也许在这种情况下我可以使用一些接口。

无论如何,谢谢你的时间。

编辑:StratifiedFold 交叉验证对我没有用,因为有时我的出现次数少于折叠次数。例如。可能会发生我有一个包含 50 个实例和 3 个类的数据集,但是 46 个可以属于一个类,2 个属于第二类和 2 个属于第三类,虽然我可以进行 3 折交叉验证,但我通常需要结果比这更多的折叠,加上即使是 3 折叠仍然留下一个类是唯一可用于一个折叠的情况。

【问题讨论】:

  • 请更准确。我不明白你这个问题。通常人们使用Stratified folds in these cases
  • 原始问题中添加了一个版本。在我的特定情况下,分层折叠对我没有用。
  • 问题不在于库,而在于您的问题设置。如果你在一个折叠中有一个类没有什么可学的(当任务是常规分类时),那么它就是病态问题。收集更多数据,重新定义测试(kfold 不适用)或研究不同的学习范式。如果您不关心并且只想强制库工作 - 只需检测标签集何时为单例并将所有预测与该标签相等,因为这是此类 dsta 的唯一有效模型。

标签: python machine-learning scikit-learn cross-validation


【解决方案1】:

说您需要收集更多数据的评论可能是正确的。但是,如果您认为您的模型有足够的数据来学习有用的东西,您可以对少数类进行过度采样(或者可能对多数类进行采样,但这听起来像是过度采样的问题)。数据集中只有一个类使您的模型几乎不可能了解该类的任何内容。

这里有一些指向 python 中过采样和欠采样库的链接。著名的不平衡学习库很棒。

https://imbalanced-learn.org/en/stable/generated/imblearn.under_sampling.RandomUnderSampler.html

https://imbalanced-learn.org/en/stable/generated/imblearn.over_sampling.RandomOverSampler.html

https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.SMOTE.html

https://imbalanced-learn.readthedocs.io/en/stable/auto_examples/over-sampling/plot_comparison_over_sampling.html#sphx-glr-auto-examples-over-sampling-plot-comparison-over-sampling-py

https://imbalanced-learn.org/en/stable/combine.html

您的案例听起来很适合 SMOTE。你还提到你想改变比率。 imblearn.over_sampling.SMOTE 中有一个名为 ratio 的参数,您可以在其中传递字典。您也可以使用百分比来执行此操作(请参阅文档)。

SMOTE 使用 K-Nearest-Neighbors 算法生成与采样数据不足的数据点“相似”的数据点。这是一种比传统过采样更强大的算法,因为当您的模型获得训练数据时,它有助于避免您的模型记忆特定示例的关键点的问题。相反,smote 创建了一个“相似”的数据点(可能在多维空间中),因此您的模型可以更好地学习泛化。

注意:不要对完整数据集使用 SMOTE,这一点至关重要。您必须仅在训练集上使用 SMOTE(即拆分后),然后在验证集和测试集上进行验证,以查看您的 SMOTE 模型是否优于其他模型。如果您不这样做,将会出现数据泄漏,并且您将获得一个与您想要的模型甚至不太相似的模型。

from collections import Counter
from imblearn.pipeline import Pipeline
from imblearn.over_sampling import SMOTE
import numpy as np
from xgboost import XGBClassifier
import warnings

warnings.filterwarnings(action='ignore', category=DeprecationWarning)
sm = SMOTE(random_state=0, n_jobs=8, ratio={'class1':100, 'class2':100, 'class3':80, 'class4':60, 'class5':90})
X_resampled, y_resampled = sm.fit_sample(X_normalized, y)

print('Original dataset shape:', Counter(y))
print('Resampled dataset shape:', Counter(y_resampled))

X_train_smote, X_test_smote, y_train_smote, y_test_smote = train_test_split(X_resampled, y_resampled)
X_train_smote.shape, X_test_smote.shape, y_train_smote.shape, y_test_smote.shape, X_resampled.shape, y_resampled.shape

smote_xgbc = XGBClassifier(n_jobs=8).fit(X_train_smote, y_train_smote)

print('TRAIN')
print(accuracy_score(smote_xgbc.predict(np.array(X_train_normalized)), y_train))
print(f1_score(smote_xgbc.predict(np.array(X_train_normalized)), y_train))

print('TEST')
print(accuracy_score(smote_xgbc.predict(np.array(X_test_normalized)), y_test))
print(f1_score(smote_xgbc.predict(np.array(X_test_normalized)), y_test))

【讨论】:

    猜你喜欢
    • 2018-02-04
    • 1970-01-01
    • 2015-06-07
    • 2019-02-28
    • 2016-03-16
    • 2014-12-06
    • 2016-05-06
    • 2014-01-03
    • 2014-11-15
    相关资源
    最近更新 更多