【Posted】: 2017-01-04 09:36:10
【Question】:
I am working through a tutorial that includes this section:
>>> import numpy as np
>>> import pandas as pd
>>> from sklearn.feature_extraction.text import TfidfVectorizer
>>> from sklearn.linear_model.logistic import LogisticRegression
>>> from sklearn.cross_validation import train_test_split, cross_val_score
>>> df = pd.read_csv('data/sms.csv')
>>> X_train_raw, X_test_raw, y_train, y_test = train_test_split(df['message'], df['label'])
>>> vectorizer = TfidfVectorizer()
>>> X_train = vectorizer.fit_transform(X_train_raw)
>>> X_test = vectorizer.transform(X_test_raw)
>>> classifier = LogisticRegression()
>>> classifier.fit(X_train, y_train)
>>> precisions = cross_val_score(classifier, X_train, y_train, cv=5, scoring='precision')
>>> print 'Precision', np.mean(precisions), precisions
>>> recalls = cross_val_score(classifier, X_train, y_train, cv=5, scoring='recall')
>>> print 'Recalls', np.mean(recalls), recalls
I then copied it with a few modifications:
ddir = (sys.argv[1])
df = pd.read_csv(ddir + '/SMSSpamCollection', sep='\t', quoting=csv.QUOTE_NONE, names=["label", "message"])
X_train_raw, X_test_raw, y_train, y_test = train_test_split(df['label'], df['message'])
vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform(X_train_raw)
X_test = vectorizer.transform(X_test_raw)
classifier = LogisticRegression()
classifier.fit(X_train, y_train)
precisions = cross_val_score(classifier, X_train, y_train, cv=5, scoring='precision')
recalls = cross_val_score(classifier, X_train, y_train, cv=5, scoring='recall')
print 'Precision', np.mean(precisions), precisions
print 'Recalls', np.mean(recalls), recalls
However, although the code is almost identical, the book's results are far better than mine:
Book:
Precision 0.992137651822 [ 0.98717949 0.98666667 1. 0.98684211 1. ]
Recall 0.677114261885 [ 0.7 0.67272727 0.6 0.68807339 0.72477064]
Mine:
Precision 0.108435683974 [ 2.33542342e-06 1.22271611e-03 1.68918919e-02 1.97530864e-01 3.26530612e-01]
Recalls 0.235220281632 [ 0.00152053 0.03370787 0.125 0.44444444 0.57142857]
Going back over the script to see what went wrong, I suspected that line 18:
X_train_raw, X_test_raw, y_train, y_test = train_test_split(df['label'], df['message'])
was the culprit, so I changed (df['label'], df['message']) to (df['message'], df['label']). But that gave me an error:
Traceback (most recent call last):
File "Chapter4[B-FLGTLG]C[Y-BCPM][G-PAR--[00].py", line 30, in <module>
precisions = cross_val_score(classifier, X_train, y_train, cv=5, scoring='precision')
File "/usr/local/lib/python2.7/dist-packages/sklearn/cross_validation.py", line 1433, in cross_val_score
for train, test in cv)
File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 800, in __call__
while self.dispatch_one_batch(iterator):
File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 658, in dispatch_one_batch
self._dispatch(tasks)
File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 566, in _dispatch
job = ImmediateComputeBatch(batch)
File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 180, in __init__
self.results = batch()
File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 72, in __call__
return [func(*args, **kwargs) for func, args, kwargs in self.items]
File "/usr/local/lib/python2.7/dist-packages/sklearn/cross_validation.py", line 1550, in _fit_and_score
test_score = _score(estimator, X_test, y_test, scorer)
File "/usr/local/lib/python2.7/dist-packages/sklearn/cross_validation.py", line 1606, in _score
score = scorer(estimator, X_test, y_test)
File "/usr/local/lib/python2.7/dist-packages/sklearn/metrics/scorer.py", line 90, in __call__
**self._kwargs)
File "/usr/local/lib/python2.7/dist-packages/sklearn/metrics/classification.py", line 1203, in precision_score
sample_weight=sample_weight)
File "/usr/local/lib/python2.7/dist-packages/sklearn/metrics/classification.py", line 984, in precision_recall_fscore_support
(pos_label, present_labels))
ValueError: pos_label=1 is not a valid label: array(['ham', 'spam'],
dtype='|S4')
What could the problem be? The data is here, in case anyone wants to try it: http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
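One possible reading of the ValueError, offered as a sketch rather than a confirmed fix: with the arguments in the order (df['message'], df['label']), y_train holds the strings 'ham' and 'spam', while the traceback shows the precision scorer looking for a positive class labelled 1. The plain-Python function below is hand-written for illustration (it is not scikit-learn's implementation, and the labels and predictions are made up). It shows the two usual ways around this: name the positive class, or encode the labels as 0/1.

```python
# Illustrative sketch only: a hand-rolled precision/recall, NOT scikit-learn's code.
# The labels and predictions below are made up for the example.

def precision_recall(y_true, y_pred, pos_label=1):
    """Return (precision, recall) for the class named by pos_label."""
    present = sorted(set(y_true) | set(y_pred))
    if pos_label not in present:
        # Same kind of complaint as the ValueError in the traceback above.
        raise ValueError("pos_label=%r is not a valid label: %r" % (pos_label, present))
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == pos_label and p == pos_label)
    fp = sum(1 for t, p in pairs if t != pos_label and p == pos_label)
    fn = sum(1 for t, p in pairs if t == pos_label and p != pos_label)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

y_true = ["ham", "spam", "spam", "ham", "spam"]
y_pred = ["ham", "spam", "ham", "ham", "ham"]

# Option 1: name the positive class explicitly.
p1, r1 = precision_recall(y_true, y_pred, pos_label="spam")

# Option 2: encode the labels as 0/1 so that the default pos_label=1 applies.
encode = {"ham": 0, "spam": 1}
p2, r2 = precision_recall([encode[y] for y in y_true],
                          [encode[y] for y in y_pred])

# String labels combined with the default pos_label=1 fail, as in the question.
try:
    precision_recall(y_true, y_pred)
    default_failed = False
except ValueError:
    default_failed = True

print(p1, r1, p2, r2, default_failed)
```

In scikit-learn itself, the two options would presumably correspond to mapping the label column to 0/1 before splitting, or passing a scorer built with make_scorer(precision_score, pos_label='spam') as the scoring argument instead of the string 'precision'.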
【Comments】:
-
Is this the same data the tutorial used to get its results?
-
Yes. Source:
We will use the SMS Spam Classification Data Set from the UCI Machine Learning Repository. The dataset can be downloaded from http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection. First, let's explore the data set and calculate some basic summary statistics using pandas
-
Why do you pass sep="\t" (and the other arguments) to read_csv? Have you checked that the data was imported correctly? If the tutorial uses the same data but does not use "\t", the data may be comma-separated rather than tab-separated.
-
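Actually, that line gave me problems, so I replaced the book's line with one from here: radimrehurek.com/data_science_python.
The lines in SMSSpamCollection are whitespace-separated, but the content is the same as sms.csv, which I do not have.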
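The import check suggested in the comments can be sketched with the standard library alone. The sample lines below are made up; they only imitate the label<TAB>message layout that the read_csv call in the question assumes (sep='\t', quoting=csv.QUOTE_NONE). The helper name load_rows is hypothetical.

```python
import csv
import io
from collections import Counter

# Made-up sample in the assumed layout: label<TAB>message, one SMS per line.
sample = (
    "ham\tOk, see you at 5\n"
    "spam\tWINNER!! Claim your \"free\" prize now\n"
    "ham\tCan you call me later?\n"
)

def load_rows(handle):
    """Parse tab-separated lines into (label, message) tuples, with quoting disabled."""
    reader = csv.reader(handle, delimiter="\t", quoting=csv.QUOTE_NONE)
    return [tuple(row) for row in reader]

rows = load_rows(io.StringIO(sample))

# Sanity checks: every row has exactly two fields, and only expected labels occur.
bad_rows = [row for row in rows if len(row) != 2]
label_counts = Counter(row[0] for row in rows)

print(len(rows), "rows;", len(bad_rows), "malformed")
print(dict(label_counts))
```

With pandas, roughly the same checks would be df.shape and df['label'].value_counts() right after read_csv. If a label other than ham or spam shows up, or the message column is empty, the separator is probably wrong.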
Tags: python csv pandas dataframe scikit-learn