In [1]:
# Importing the libraries 导入库
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# 使图像能够调整
%matplotlib notebook
#中文字体显示
plt.rc('font', family='SimHei', size=8)
In [2]:
# Load the mall customer data. Goal: find target customers by clustering on
# annual income and spending score.
DATA_PATH = 'Mall_Customers.csv'  # relative to the notebook's working directory
dataset = pd.read_csv(DATA_PATH)
# Preview only; the full frame is too long to be useful as cell output.
dataset.head()
Out[2]:
In [4]:
# Feature matrix from positional columns 3 and 4 (slice 3:5). According to the
# author these are the customer's annual income and spending score.
X = dataset.iloc[:, 3:5].to_numpy()
# Later cells index X[:, 0] and X[:, 1]; fail here, not there, if the CSV has too few columns.
assert X.shape[1] == 2, 'expected 2 feature columns, got shape %s' % (X.shape,)
# Preview the first rows instead of dumping the whole array.
X[:5]
Out[4]:
In [12]:
from sklearn.cluster import KMeans

# Elbow method: fit K-means for k = 1..MAX_CLUSTERS and record each model's inertia_,
# i.e. the within-cluster sum of squared distances of samples to their closest
# centroid (WCSS). The k where the curve stops dropping sharply is a reasonable choice.
MAX_CLUSTERS = 10
cluster_counts = range(1, MAX_CLUSTERS + 1)
wcss = []
for k in cluster_counts:
    # Same settings as the final model: k-means++ seeding, 10 restarts, fixed seed.
    model = KMeans(n_clusters=k, max_iter=300, n_init=10, init='k-means++', random_state=0)
    model.fit(X)
    wcss.append(model.inertia_)  # within-cluster SSE, not a between-cluster distance

# Explicit new figure: with the interactive notebook backend, the implicit "current"
# figure may be one left open by an earlier cell.
fig, ax = plt.subplots()
ax.plot(cluster_counts, wcss, marker='o')
ax.set_title('手肘图像')
ax.set_xlabel('集群数')
ax.set_ylabel('组内平方和 (WCSS)')
plt.show()
In [18]:
# Final model with k = 5, the cluster count this notebook settles on after the elbow plot.
# 'k-means++' seeding spreads the initial centroids apart, which avoids the
# poor-initialisation trap of purely random starting centroids.
kmeans = KMeans(
    n_clusters=5,
    init='k-means++',
    n_init=10,
    max_iter=300,
    random_state=0,
)
y_kmeans = kmeans.fit_predict(X)

print(y_kmeans)                 # cluster label assigned to each row of X
print(kmeans.cluster_centers_)  # coordinates of the fitted centroids
In [15]:
# Scatter the customers coloured by K-means cluster, with the centroids drawn on top.
# The author's reading of each cluster for this run (k=5, random_state=0):
#   0: rational   1: standard   2: target customers   3: irrational, be careful
#   4: spending-sensitive customers
# NOTE(review): cluster ids are arbitrary; re-check these labels if the data, k or seed change.
CLUSTER_COLORS = ['red', 'blue', 'green', 'cyan', 'magenta']

# Explicit new figure: with the interactive notebook backend, bare plt.scatter calls
# could draw onto a figure left open by an earlier cell (e.g. the elbow plot).
fig, ax = plt.subplots()
for cluster_id in range(kmeans.n_clusters):
    members = X[y_kmeans == cluster_id]
    # More clusters than listed colours -> c=None uses matplotlib's default colour cycle.
    color = CLUSTER_COLORS[cluster_id] if cluster_id < len(CLUSTER_COLORS) else None
    ax.scatter(members[:, 0], members[:, 1], s=100, c=color,
               label='Cluster {}'.format(cluster_id))
ax.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1],
           s=300, c='yellow', label='Centroids')
ax.set_title('顾客群组')
ax.set_xlabel('年收入')
ax.set_ylabel('购物指数')
ax.legend()  # show the cluster labels
plt.show()