对分类问题进行简单可视化
导入数据
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Load the exam-score data set. The file is read with header=None, so the
# three column names are supplied explicitly here.
column_names = ['Exam 1', 'Exam 2', 'Admitted']
data = pd.read_csv('LogiReg_data.txt', header=None, names=column_names)
# Preview the first rows to sanity-check the parsed columns.
data.head()
|   | Exam 1 | Exam 2 | Admitted |
|---|-----------|-----------|---|
| 0 | 34.623660 | 78.024693 | 0 |
| 1 | 30.286711 | 43.894998 | 0 |
| 2 | 35.847409 | 72.902198 | 0 |
| 3 | 60.182599 | 86.308552 | 1 |
| 4 | 79.032736 | 75.344376 | 1 |
根据标签画出直方图
# Bar chart of how many samples carry each "Admitted" label.
# Split the frame into the feature columns and the single label column.
X = data.loc[:, data.columns != "Admitted"]
y = data.loc[:, data.columns == "Admitted"]
# Labels of the samples with Admitted == 1 and Admitted == 0 respectively.
y_true = y.loc[y["Admitted"] == 1, "Admitted"]
y_false = y.loc[y["Admitted"] == 0, "Admitted"]
# Heights are listed in the same order as the x positions, so the bar drawn
# at tick 0 is the count of label 0 and the bar at tick 1 is the count of
# label 1.
plt.bar([0, 1], [len(y_false), len(y_true)], 0.3)
plt.xticks([0, 1])
plt.show()


根据样本分布画出散点图
# Scatter plot of "Exam 1" against "Exam 2", coloured by the "Admitted" label.
is_admitted = data["Admitted"] == 1
is_rejected = data["Admitted"] == 0
# Rows with Admitted == 1 are drawn in red.
x_1, y_1 = data.loc[is_admitted, "Exam 1"], data.loc[is_admitted, "Exam 2"]
# Rows with Admitted == 0 are drawn in blue.
x_2, y_2 = data.loc[is_rejected, "Exam 1"], data.loc[is_rejected, "Exam 2"]
plt.scatter(x_1, y_1, color='red')
plt.scatter(x_2, y_2, color="blue")
plt.title("Scatter")
plt.xlabel("Exam 1")
plt.ylabel("Exam 2")
plt.show()


训练模型
import matplotlib as mpl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import recall_score
# Fit a decision tree on the feature columns and report its recall.
# "Admitted" is a 0/1 class label, so a classifier is used: its predictions
# are always class labels, which is what recall_score expects.
from sklearn.tree import DecisionTreeClassifier

X = data.loc[:, data.columns != "Admitted"]
Y = data.loc[:, data.columns == "Admitted"]
dr = DecisionTreeClassifier()
# fit() wants a 1-D target, hence the ravel() of the single-column frame.
dr.fit(X, Y.values.ravel())
# Predict with the model fitted above (`dr`).
y = dr.predict(X)
# NOTE(review): recall is computed on the same rows the tree was fitted on,
# so it is an optimistic figure rather than a generalisation estimate; hold
# out a test split (train_test_split is already imported) for a real score.
score = recall_score(Y, y)
score
1.0
画出分类可视化图
# Decision-region plot: predict the label on a dense grid spanning the range
# of the two features, shade the grid by predicted label, then overlay the
# samples.
N = 500  # number of grid points along the first feature
M = 500  # number of grid points along the second feature
# Assumes X holds exactly two feature columns, so min()/max() unpack into
# two values each -- TODO confirm against the cell that builds X.
x1_min, x2_min = X.min()
x1_max, x2_max = X.max()
t1 = np.linspace(x1_min, x1_max, N)
t2 = np.linspace(x2_min, x2_max, M)
x1, x2 = np.meshgrid(t1, t2)
# One row per grid point, one column per feature.
x_show = np.column_stack((x1.ravel(), x2.ravel()))
# Reuse X's column names so the model sees the same feature names at
# predict time as in X (avoids a feature-name mismatch warning when the
# model was fitted on a DataFrame).
y_predict = dr.predict(pd.DataFrame(x_show, columns=X.columns))
cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0'])
# Not used in this cell; kept in case a later cell refers to it.
cm_dark = mpl.colors.ListedColormap(['g', 'r'])
plt.pcolormesh(x1, x2, y_predict.reshape(x1.shape), cmap=cm_light)
# Overlay the samples: Admitted == 1 in red, Admitted == 0 in blue.
x_1 = data.loc[data["Admitted"] == 1, "Exam 1"]
y_1 = data.loc[data["Admitted"] == 1, "Exam 2"]
x_2 = data.loc[data["Admitted"] == 0, "Exam 1"]
y_2 = data.loc[data["Admitted"] == 0, "Exam 2"]
plt.scatter(x_1, y_1, color='red')
plt.scatter(x_2, y_2, color="blue")
plt.xlabel("Exam 1")
plt.ylabel("Exam 2")
plt.title("Scatter")
plt.show()

