【发布时间】:2018-08-14 10:47:47
【问题描述】:
正如这里所指定的,https://stackoverflow.com/a/35662770/5757129,我存储了我的第一个模型的系数和截距。稍后,我将它们作为初始化程序传递给我的第二个 fit() ,如下所示,用于在旧模型之上学习新数据。
from sklearn import neighbors, linear_model
import numpy as np
import pickle
import os
def train_data():
x1 = [[8, 9], [20, 22], [16, 18], [8,4]]
y1 = [0, 1, 2, 3]
#classes = np.arange(10)
#sgd_clf = linear_model.SGDClassifier(learning_rate = 'constant', eta0 = 0.1, shuffle = False, n_iter = 1,warm_start=True)
sgd_clf = linear_model.SGDClassifier(loss="hinge",max_iter=10000)
sgd_clf.fit(x1,y1)
coef = sgd_clf.coef_
intercept = sgd_clf.intercept_
return coef, intercept
def train_new_data(coefs,intercepts):
x2 = [[18, 19],[234,897],[20, 122], [16, 118]]
y2 = [4,5,6,7]
sgd_clf1 = linear_model.SGDClassifier(loss="hinge",max_iter=10000)
new_model = sgd_clf1.fit(x2,y2,coef_init=coefs,intercept_init=intercepts)
return new_model
if __name__ == "__main__":
coefs,intercepts= train_data()
new_model = train_new_data(coefs,intercepts)
print(new_model.predict([[16, 118]]))
print(new_model.predict([[18, 19]]))
print(new_model.predict([[8,9]]))
print(new_model.predict([[20,22]]))
当我运行它时,我得到了仅从 new_model 训练的标签。例如,print(new_model.predict([[8,9]])) 必须将标签打印为 0,print(new_model.predict([[20,22]])) 必须将标签打印为 1。但它会打印从 4 到 7 匹配的标签。
我是否以错误的方式将 coef 和从旧模型拦截到新模型?
编辑:根据@vital_dml 答案重新构建问题
【问题讨论】:
标签: python machine-learning scikit-learn