UPD: 2021-05-10
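For sklearn >= 0.20 we can use sklearn.compose.ColumnTransformer.
Here is a small example. As written, it also relies on newer features: fetch_openml(..., as_frame=True) needs sklearn >= 0.22, and letting OneHotEncoder treat the missing values in 'embarked' as a category needs sklearn >= 0.24: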
Imports and data loading:
# Author: Pedro Morales <part.morales@gmail.com>
#
# License: BSD 3 clause
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV
np.random.seed(0)
# Load data from https://www.openml.org/d/40945
X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
Pipeline-aware data preprocessing with ColumnTransformer:
numeric_features = ['age', 'fare']
numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())])
categorical_features = ['embarked', 'sex', 'pclass']
categorical_transformer = OneHotEncoder(handle_unknown='ignore')
preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features)])
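Note that ColumnTransformer only keeps the columns listed in its transformers: by default (remainder="drop") every other column of the input, e.g. the remaining Titanic columns here, is discarded. A minimal sketch on a small hypothetical DataFrame (the column names and values, including the extra "ticket_id" column, are made up for illustration) showing the difference with remainder="passthrough":

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Hypothetical toy data: two numeric columns, one categorical, one extra column
df = pd.DataFrame({
    "age": [22.0, 38.0, 26.0],
    "fare": [7.25, 71.28, 7.92],
    "sex": ["male", "female", "female"],
    "ticket_id": [101, 102, 103],  # not listed in any transformer
})

transformers = [
    ("num", StandardScaler(), ["age", "fare"]),
    ("cat", OneHotEncoder(handle_unknown="ignore"), ["sex"]),
]

# Default remainder="drop": "ticket_id" is silently discarded
dropped = ColumnTransformer(transformers).fit_transform(df)
# remainder="passthrough": "ticket_id" is appended unchanged as the last column
kept = ColumnTransformer(transformers, remainder="passthrough").fit_transform(df)

print(dropped.shape)  # (3, 4): 2 scaled + 2 one-hot columns
print(kept.shape)     # (3, 5)
```

Whether dropping is what you want depends on the data; listing the columns explicitly, as the example does, makes the selection visible.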
Classification
# Append classifier to preprocessing pipeline.
# Now we have a full prediction pipeline.
clf = Pipeline(steps=[('preprocessor', preprocessor),
                      ('classifier', LogisticRegression())])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=0)
clf.fit(X_train, y_train)
print("model score: %.3f" % clf.score(X_test, y_test))
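GridSearchCV is imported above but never used. Because the whole preprocessing is now part of the pipeline, it can be tuned together with the classifier by addressing nested steps as `<step>__<sub-step>__<param>`. A minimal sketch, assuming a small synthetic DataFrame in place of the Titanic data (so it runs without a download); the grid values are arbitrary:

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

rng = np.random.RandomState(0)
n = 200
# Synthetic stand-in for the Titanic frame (assumption, not the real data)
X = pd.DataFrame({
    "age": rng.normal(30, 10, n),
    "fare": rng.exponential(30, n),
    "sex": rng.choice(["male", "female"], n),
})
X.loc[::10, "age"] = np.nan  # some missing ages for the imputer to fill

# Target mostly follows "sex", with 10% of labels flipped as noise
y = (X["sex"] == "female").to_numpy().astype(int)
flip = rng.rand(n) < 0.1
y[flip] = 1 - y[flip]

preprocessor = ColumnTransformer(transformers=[
    ("num", Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler())]), ["age", "fare"]),
    ("cat", OneHotEncoder(handle_unknown="ignore"), ["sex"]),
])
clf = Pipeline(steps=[("preprocessor", preprocessor),
                      ("classifier", LogisticRegression())])

# Tune an imputer option and the regularization strength in one search
param_grid = {
    "preprocessor__num__imputer__strategy": ["mean", "median"],
    "classifier__C": [0.1, 1.0, 10.0],
}
search = GridSearchCV(clf, param_grid, cv=5)
search.fit(X, y)
print(search.best_params_)
print("best CV accuracy: %.3f" % search.best_score_)
```

Since the imputer and scaler are refit inside each CV fold, no statistics from the validation folds leak into preprocessing.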
Old answer:
Assume you have the following DF:
In [163]: df
Out[163]:
     a     b    c    d
0  aaa  1.01  xxx  111
1  bbb  2.02  yyy  222
2  ccc  3.03  zzz  333
In [164]: df.dtypes
Out[164]:
a     object
b    float64
c     object
d      int64
dtype: object
You can find all numeric columns:
In [165]: num_cols = df.columns[df.dtypes.apply(lambda c: np.issubdtype(c, np.number))]
In [166]: num_cols
Out[166]: Index(['b', 'd'], dtype='object')
In [167]: df[num_cols]
Out[167]:
      b    d
0  1.01  111
1  2.02  222
2  3.03  333
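pandas also has a built-in for this selection, DataFrame.select_dtypes. A minimal sketch rebuilding the DF above and comparing both approaches:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "a": ["aaa", "bbb", "ccc"],
    "b": [1.01, 2.02, 3.03],
    "c": ["xxx", "yyy", "zzz"],
    "d": [111, 222, 333],
})

# Approach from the answer: test each column dtype against np.number
num_cols = df.columns[df.dtypes.apply(lambda c: np.issubdtype(c, np.number))]
# Equivalent pandas built-in
num_cols_alt = df.select_dtypes(include="number").columns

print(list(num_cols))      # ['b', 'd']
print(list(num_cols_alt))  # ['b', 'd']
```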
Then apply StandardScaler only to those numeric columns:
In [168]: scaler = StandardScaler()
In [169]: df[num_cols] = scaler.fit_transform(df[num_cols])
In [170]: df
Out[170]:
     a         b    c         d
0  aaa -1.224745  xxx -1.224745
1  bbb  0.000000  yyy  0.000000
2  ccc  1.224745  zzz  1.224745
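The same "scale only the numeric columns" step can also be expressed with ColumnTransformer from the updated answer, using make_column_selector (available in sklearn >= 0.22) to pick the columns by dtype. A minimal sketch on the DF above; the fitted `ct` can then be reused on new data with `ct.transform`, which the in-place assignment does not give you:

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.preprocessing import StandardScaler

df = pd.DataFrame({
    "a": ["aaa", "bbb", "ccc"],
    "b": [1.01, 2.02, 3.03],
    "c": ["xxx", "yyy", "zzz"],
    "d": [111, 222, 333],
})

# Scale every numeric column; pass the non-numeric ones through untouched
ct = ColumnTransformer(
    [("num", StandardScaler(), make_column_selector(dtype_include=np.number))],
    remainder="passthrough",
)
out = ct.fit_transform(df)

# Column order of `out`: scaled b, scaled d, then the passthrough columns a, c
print(out)
```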
Now you can one-hot encode the categorical (non-numeric) columns...
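One way to do that in plain pandas is pd.get_dummies. A minimal sketch on the scaled DF above (values copied from the output); unlike sklearn's OneHotEncoder in the updated answer, get_dummies does not remember the categories, so a test set with different values can produce different columns:

```python
import pandas as pd

# Frame after scaling (values taken from the output above)
df = pd.DataFrame({
    "a": ["aaa", "bbb", "ccc"],
    "b": [-1.224745, 0.0, 1.224745],
    "c": ["xxx", "yyy", "zzz"],
    "d": [-1.224745, 0.0, 1.224745],
})

# Non-numeric columns: a, c
cat_cols = df.select_dtypes(exclude="number").columns
# Replace each categorical column with one 0/1 column per distinct value
encoded = pd.get_dummies(df, columns=list(cat_cols), dtype=int)
print(encoded.columns.tolist())
# ['b', 'd', 'a_aaa', 'a_bbb', 'a_ccc', 'c_xxx', 'c_yyy', 'c_zzz']
```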