【问题标题】:cross validation and text classification for imbalanced data不平衡数据的交叉验证和文本分类
【发布时间】:2018-11-15 16:16:01
【问题描述】:

我是 NLP 新手,我正在尝试构建一个文本分类器,但我的数据目前不平衡。最高类别有多达 280 个条目,而最低类别有 30 个条目。 我正在尝试对当前数据使用交叉验证技术,但是在寻找了几天之后我无法实现它。它看起来很简单,但我仍然无法实现它。这是我的代码

y = resample.Subsystem
X = resample['new description']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42)
from sklearn.feature_extraction.text import CountVectorizer
count_vect = CountVectorizer()
X_train_counts = count_vect.fit_transform(X_train)
X_train_counts.shape
from sklearn.feature_extraction.text import TfidfTransformer
tfidf_transformer = TfidfTransformer()
X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
X_train_tfidf.shape
#SVM
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
text_clf_svm = Pipeline([('vect', CountVectorizer(stop_words='english')),('tfidf', TfidfTransformer()),('clf-svm', SGDClassifier(loss='hinge', penalty='l2',alpha=1e-3, n_iter=5, random_state=42)),])
text_clf_svm.fit(X_train, y_train)
predicted_svm = text_clf_svm.predict(X_test)
print('The best accuracy is : ',np.mean(predicted_svm == y_test))

我已经进一步做了一些 gridsearch 和 Stemmer,但现在我将对此代码进行交叉验证。我已经很好地清理了数据,但我仍然获得 60% 的准确度 任何帮助将不胜感激

【问题讨论】:

    标签: python-3.x machine-learning svm pipeline cross-validation


    【解决方案1】:

    尝试进行过采样或欠采样。由于数据高度不平衡,因此对具有更多数据点的类有更多的偏见。在过采样/欠采样之后,偏差会非常小,准确度会提高。

    您可以使用 MLP 代替 SVM。即使数据不平衡,它也能提供良好的结果。

    【讨论】:

    • 我对数据进行了过采样,但这是我想避免的,因为那只是复制数据,我没有从中学到任何东西
    • 尝试欠采样但是你的数据点太少了,还是可以试试。过采样会提高准确性吗?
    • 确实如此,但我还是想避免它,我觉得这不是一个好方法
    【解决方案2】:
    from sklearn.model_selection import StratifiedKFold
    skf = StratifiedKFold(n_splits=5, random_state=None)
    # X is the feature set and y is the target
    from sklearn.model_selection import RepeatedKFold 
    kf = RepeatedKFold(n_splits=20, n_repeats=10, random_state=None) 
    
    for train_index, test_index in kf.split(X):
      #print("Train:", train_index, "Validation:",test_index)
      X_train, X_test = X[train_index], X[test_index] 
      y_train, y_test = y[train_index], y[test_index]
    

    【讨论】:

      猜你喜欢
      • 2021-10-08
      • 2018-07-26
      • 2019-06-04
      • 2015-12-13
      • 2016-10-27
      • 1970-01-01
      • 2017-11-17
      • 2019-10-15
      • 2019-08-21
      相关资源
      最近更新 更多