【Question title】:How can I use SGDClassifier hinge loss with GridSearchCV using the log loss metric?
【Posted】:2019-09-17 12:14:11
【Question】:

I know that SGDClassifier with hinge loss does not support probability estimates. So how can I use it with GridSearchCV when scoring with the log_loss metric?

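from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV

# X_train, y_train: the asker's training data (not shown in the question)
clf = SGDClassifier(loss='hinge')

grid_params = {'alpha': [0.0001, 0.001, 0.01]}
grid_search = GridSearchCV(clf, grid_params, scoring='neg_log_loss')
grid_search.fit(X_train, y_train)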

This returns:

AttributeError: probability estimates are not available for loss='hinge'

Is there any way I can make this work?

【Comments on the question】:

    Tags: python scikit-learn grid-search


    【Solution 1】:

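    Changing the loss from hinge to log changes the algorithm from a linear SVM to logistic regression, so I don't think it is possible to get probabilities from the hinge-loss model directly.

    However, you can set the SGDClassifier as the base estimator of scikit-learn's CalibratedClassifierCV, which fits a calibration step on top of the classifier's decision scores and produces probability estimates.

    Here is an example: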
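To make the difference concrete, here is a minimal sketch (not part of the original answer) that checks which methods each object exposes. The iris data, `max_iter=1000`, and `random_state=0` are illustrative assumptions; the estimator is passed positionally so the snippet does not depend on the keyword name.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier

X, y = load_iris(return_X_y=True)

# A hinge-loss SGDClassifier only exposes raw decision scores.
svm = SGDClassifier(loss='hinge', max_iter=1000, random_state=0).fit(X, y)
has_proba_raw = hasattr(svm, 'predict_proba')
scores = svm.decision_function(X[:5])  # one score per class, not probabilities

# Wrapping the same kind of estimator adds a calibrated predict_proba.
calibrated = CalibratedClassifierCV(
    SGDClassifier(loss='hinge', max_iter=1000, random_state=0),
    method='sigmoid',
    cv=3,
).fit(X, y)
proba = calibrated.predict_proba(X[:5])

print(has_proba_raw, scores.shape, proba.shape)
```

The decision scores are unbounded margins, which is why a scorer such as `neg_log_loss` cannot use them as they are.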

    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.linear_model import SGDClassifier
    from sklearn.model_selection import GridSearchCV, train_test_split
    from sklearn.datasets import load_iris
    
    # load some example data
    data = load_iris()
    X = data['data']
    y = data['target']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    
    clf = SGDClassifier(loss='hinge', max_iter=100)
    # set the SGD classifier as the base estimator
    # (scikit-learn 1.2 renamed `base_estimator` to `estimator`; on newer versions use that name here and in the grid keys)
    calibrated_clf = CalibratedClassifierCV(base_estimator=clf, method='sigmoid', cv=3)

    # note the 'base_estimator__' prefix: it routes the parameter to the wrapped SGDClassifier
    grid_params = {'base_estimator__alpha': [0.0001, 0.001, 0.01]}
    grid_search = GridSearchCV(estimator=calibrated_clf, param_grid=grid_params, cv=3)
    grid_search.fit(X_train, y_train)
    
    print(grid_search.best_params_)
    
    {'base_estimator__alpha': 0.0001}
    

    Now fit the calibrated classifier with the best parameters:

    calibrated_clf.set_params(**grid_search.best_params_)
    calibrated_clf.fit(X_train, y_train)
    preds = calibrated_clf.predict_proba(X_test)
    print(preds)
    
    # probabilities for each of the 3 classes:
    
    array([[7.62825746e-02, 5.24891243e-01, 3.98826183e-01],
           [9.24810700e-01, 7.50659865e-02, 1.23313813e-04],
           [8.40690799e-01, 1.59138563e-01, 1.70637465e-04],
           [7.10696359e-01, 2.88969750e-01, 3.33891072e-04],
           [7.99360835e-02, 7.83076911e-01, 1.36987006e-01],
           [9.90417693e-03, 7.72846023e-02, 9.12811221e-01],
           [1.07116396e-02, 3.03030985e-01, 6.86257375e-01],
           [1.43944221e-02, 1.17223024e-01, 8.68382554e-01],
           [1.11659634e-01, 7.35051942e-01, 1.53288424e-01],
           [8.30127745e-03, 1.39546231e-01, 8.52152492e-01],
           [2.07825315e-02, 1.56925620e-01, 8.22291849e-01],
           [8.88421387e-01, 1.11384933e-01, 1.93680314e-04],
           [6.90696963e-01, 3.09038629e-01, 2.64408097e-04],
           [1.26043359e-01, 5.78366890e-01, 2.95589750e-01],
           [3.83356263e-03, 4.06197230e-01, 5.89969207e-01],
           [7.78520570e-01, 2.21144460e-01, 3.34969184e-04],
           [5.11227086e-02, 6.32329915e-01, 3.16547377e-01],
           [8.24310445e-01, 1.75412791e-01, 2.76763715e-04],
           [3.50118697e-02, 3.91028064e-01, 5.73960067e-01],
           [1.23034113e-01, 7.32289832e-01, 1.44676055e-01],
           [3.44588463e-01, 5.92799831e-01, 6.26117056e-02],
           [2.67170305e-02, 5.78551461e-01, 3.94731509e-01],
           [5.92943916e-02, 5.57127843e-01, 3.83577765e-01],
           [7.16297083e-01, 2.83282184e-01, 4.20732771e-04],
           [7.82091800e-03, 1.30949377e-01, 8.61229705e-01],
           [1.70781668e-01, 5.47432635e-01, 2.81785697e-01],
           [8.38288358e-01, 1.61495161e-01, 2.16480625e-04],
           [2.11106665e-02, 4.66121567e-01, 5.12767766e-01],
           [9.20496389e-02, 6.29184167e-01, 2.78766194e-01],
           [1.29649784e-02, 2.73576019e-01, 7.13459002e-01]])
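The grid search in the example above uses the default scorer. Since the question asks about the log loss metric, the following is a hedged sketch of the same setup with `scoring='neg_log_loss'`, which works here because the calibrated wrapper provides `predict_proba`. The `prefix` variable, `random_state=0`, `stratify=y`, and `max_iter=1000` are assumptions added for illustration; `prefix` only exists to pick whichever parameter name the installed scikit-learn version uses.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import log_loss
from sklearn.model_selection import GridSearchCV, train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)

clf = SGDClassifier(loss='hinge', max_iter=1000, random_state=0)
# Passed positionally so this runs whether the keyword is
# `estimator` (newer releases) or `base_estimator` (older releases).
calibrated_clf = CalibratedClassifierCV(clf, method='sigmoid', cv=3)

# Pick the nested-parameter prefix that this version actually exposes.
own_params = calibrated_clf.get_params(deep=False)
prefix = 'estimator' if 'estimator' in own_params else 'base_estimator'
grid_params = {prefix + '__alpha': [0.0001, 0.001, 0.01]}

# neg_log_loss calls predict_proba on the calibrated wrapper.
grid_search = GridSearchCV(calibrated_clf, grid_params, scoring='neg_log_loss', cv=3)
grid_search.fit(X_train, y_train)

# With the default refit=True, the best model is refit on all of X_train.
test_log_loss = log_loss(y_test, grid_search.predict_proba(X_test))
print(grid_search.best_params_, grid_search.best_score_, test_log_loss)
```

`best_score_` is the negated log loss, so values closer to zero are better. Because the search refits the best model by default, `grid_search.predict_proba` can be used directly instead of refitting by hand.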
    

    【Comments】:

    • I was confused about how to pass hyperparameters for the SGD hinge-loss classifier. I never knew we could pass something like base_estimator__alpha. Thanks a lot.
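On the naming point raised in this comment: the valid nested names can be listed from the wrapper itself with `get_params()`. This is a small sketch rather than part of the original thread; the variable name `nested` is illustrative.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import SGDClassifier

# The estimator is passed positionally to stay independent of the keyword name.
calibrated_clf = CalibratedClassifierCV(SGDClassifier(loss='hinge'), method='sigmoid', cv=3)

# get_params() (deep by default) lists nested parameters as '<component>__<name>'.
# Depending on the scikit-learn version, the component is 'base_estimator' or 'estimator'.
nested = sorted(name for name in calibrated_clf.get_params() if name.endswith('__alpha'))
print(nested)
```

Any key printed this way can be used as a key in `param_grid` or passed to `set_params`.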