一、导入常用库

In [1]:
# Import the libraries used throughout the notebook.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Select the interactive notebook backend so figures can be adjusted in place.
%matplotlib notebook 
# Use the SimHei font so the Chinese titles and axis labels are displayed
# (assumes SimHei is installed on this machine -- TODO confirm).
plt.rc('font', family='SimHei', size=8)

二、导入数据

In [2]:
# Goal of the analysis: find the target customers by clustering on annual
# income and spending score.
dataset = pd.read_csv('Mall_Customers.csv') # read the customer table from the working directory
dataset
Out[2]:
  CustomerID Genre Age Annual Income (k$) Spending Score (1-100)
0 1 Male 19 15 39
1 2 Male 21 15 81
2 3 Female 20 16 6
3 4 Female 23 16 77
4 5 Female 31 17 40
5 6 Female 22 17 76
6 7 Female 35 18 6
7 8 Female 23 18 94
8 9 Male 64 19 3
9 10 Female 30 19 72
10 11 Male 67 19 14
11 12 Female 35 19 99
12 13 Female 58 20 15
13 14 Female 24 20 77
14 15 Male 37 20 13
15 16 Male 22 20 79
16 17 Female 35 21 35
17 18 Male 20 21 66
18 19 Male 52 23 29
19 20 Female 35 23 98
20 21 Male 35 24 35
21 22 Male 25 24 73
22 23 Female 46 25 5
23 24 Male 31 25 73
24 25 Female 54 28 14
25 26 Male 29 28 82
26 27 Female 45 28 32
27 28 Male 35 28 61
28 29 Female 40 29 31
29 30 Female 23 29 87
... ... ... ... ... ...
170 171 Male 40 87 13
171 172 Male 28 87 75
172 173 Male 36 87 10
173 174 Male 36 87 92
174 175 Female 52 88 13
175 176 Female 30 88 86
176 177 Male 58 88 15
177 178 Male 27 88 69
178 179 Male 59 93 14
179 180 Male 35 93 90
180 181 Female 37 97 32
181 182 Female 32 97 86
182 183 Male 46 98 15
183 184 Female 29 98 88
184 185 Female 41 99 39
185 186 Male 30 99 97
186 187 Female 54 101 24
187 188 Male 28 101 68
188 189 Female 41 103 17
189 190 Female 36 103 85
190 191 Female 34 103 23
191 192 Female 32 103 69
192 193 Male 33 113 8
193 194 Female 38 113 91
194 195 Female 47 120 16
195 196 Female 35 120 79
196 197 Female 45 126 28
197 198 Male 32 126 74
198 199 Male 32 137 18
199 200 Male 30 137 83

200 rows × 5 columns

In [4]:
# Feature matrix for clustering: annual income (k$) and spending score (1-100).
# The columns are selected by name instead of by position (3:5) so the right
# features are picked even if the CSV's column order changes.
X = dataset[['Annual Income (k$)', 'Spending Score (1-100)']].values
X
Out[4]:
array([[ 15,  39],
       [ 15,  81],
       [ 16,   6],
       [ 16,  77],
       [ 17,  40],
       [ 17,  76],
       [ 18,   6],
       [ 18,  94],
       [ 19,   3],
       [ 19,  72],
       [ 19,  14],
       [ 19,  99],
       [ 20,  15],
       [ 20,  77],
       [ 20,  13],
       [ 20,  79],
       [ 21,  35],
       [ 21,  66],
       [ 23,  29],
       [ 23,  98],
       [ 24,  35],
       [ 24,  73],
       [ 25,   5],
       [ 25,  73],
       [ 28,  14],
       [ 28,  82],
       [ 28,  32],
       [ 28,  61],
       [ 29,  31],
       [ 29,  87],
       [ 30,   4],
       [ 30,  73],
       [ 33,   4],
       [ 33,  92],
       [ 33,  14],
       [ 33,  81],
       [ 34,  17],
       [ 34,  73],
       [ 37,  26],
       [ 37,  75],
       [ 38,  35],
       [ 38,  92],
       [ 39,  36],
       [ 39,  61],
       [ 39,  28],
       [ 39,  65],
       [ 40,  55],
       [ 40,  47],
       [ 40,  42],
       [ 40,  42],
       [ 42,  52],
       [ 42,  60],
       [ 43,  54],
       [ 43,  60],
       [ 43,  45],
       [ 43,  41],
       [ 44,  50],
       [ 44,  46],
       [ 46,  51],
       [ 46,  46],
       [ 46,  56],
       [ 46,  55],
       [ 47,  52],
       [ 47,  59],
       [ 48,  51],
       [ 48,  59],
       [ 48,  50],
       [ 48,  48],
       [ 48,  59],
       [ 48,  47],
       [ 49,  55],
       [ 49,  42],
       [ 50,  49],
       [ 50,  56],
       [ 54,  47],
       [ 54,  54],
       [ 54,  53],
       [ 54,  48],
       [ 54,  52],
       [ 54,  42],
       [ 54,  51],
       [ 54,  55],
       [ 54,  41],
       [ 54,  44],
       [ 54,  57],
       [ 54,  46],
       [ 57,  58],
       [ 57,  55],
       [ 58,  60],
       [ 58,  46],
       [ 59,  55],
       [ 59,  41],
       [ 60,  49],
       [ 60,  40],
       [ 60,  42],
       [ 60,  52],
       [ 60,  47],
       [ 60,  50],
       [ 61,  42],
       [ 61,  49],
       [ 62,  41],
       [ 62,  48],
       [ 62,  59],
       [ 62,  55],
       [ 62,  56],
       [ 62,  42],
       [ 63,  50],
       [ 63,  46],
       [ 63,  43],
       [ 63,  48],
       [ 63,  52],
       [ 63,  54],
       [ 64,  42],
       [ 64,  46],
       [ 65,  48],
       [ 65,  50],
       [ 65,  43],
       [ 65,  59],
       [ 67,  43],
       [ 67,  57],
       [ 67,  56],
       [ 67,  40],
       [ 69,  58],
       [ 69,  91],
       [ 70,  29],
       [ 70,  77],
       [ 71,  35],
       [ 71,  95],
       [ 71,  11],
       [ 71,  75],
       [ 71,   9],
       [ 71,  75],
       [ 72,  34],
       [ 72,  71],
       [ 73,   5],
       [ 73,  88],
       [ 73,   7],
       [ 73,  73],
       [ 74,  10],
       [ 74,  72],
       [ 75,   5],
       [ 75,  93],
       [ 76,  40],
       [ 76,  87],
       [ 77,  12],
       [ 77,  97],
       [ 77,  36],
       [ 77,  74],
       [ 78,  22],
       [ 78,  90],
       [ 78,  17],
       [ 78,  88],
       [ 78,  20],
       [ 78,  76],
       [ 78,  16],
       [ 78,  89],
       [ 78,   1],
       [ 78,  78],
       [ 78,   1],
       [ 78,  73],
       [ 79,  35],
       [ 79,  83],
       [ 81,   5],
       [ 81,  93],
       [ 85,  26],
       [ 85,  75],
       [ 86,  20],
       [ 86,  95],
       [ 87,  27],
       [ 87,  63],
       [ 87,  13],
       [ 87,  75],
       [ 87,  10],
       [ 87,  92],
       [ 88,  13],
       [ 88,  86],
       [ 88,  15],
       [ 88,  69],
       [ 93,  14],
       [ 93,  90],
       [ 97,  32],
       [ 97,  86],
       [ 98,  15],
       [ 98,  88],
       [ 99,  39],
       [ 99,  97],
       [101,  24],
       [101,  68],
       [103,  17],
       [103,  85],
       [103,  23],
       [103,  69],
       [113,   8],
       [113,  91],
       [120,  16],
       [120,  79],
       [126,  28],
       [126,  74],
       [137,  18],
       [137,  83]], dtype=int64)

三、寻找最佳组数(手肘法:找到拐点)

In [12]:
from sklearn.cluster import KMeans

# Elbow method: fit K-means for k = 1..10 and record each model's inertia,
# i.e. the within-cluster sum of squared distances (WCSS) of the samples to
# their closest centroid. The "elbow" of the curve suggests the cluster count.
k_values = range(1, 11)
wcss = []
for i in k_values:
    kmeans = KMeans(n_clusters=i, max_iter=300, n_init=10, init='k-means++', random_state=0)
    kmeans.fit(X)
    wcss.append(kmeans.inertia_)    # within-cluster sum of squares for this k

# Draw on a dedicated figure: with the interactive `%matplotlib notebook`
# backend, bare pyplot calls keep drawing onto whichever figure is still active.
fig, ax = plt.subplots()
ax.plot(k_values, wcss)
ax.set_title(u'手肘图像')
ax.set_xlabel(u'集群数')
# inertia_ measures within-cluster spread, so the axis is labelled as WCSS
# rather than as a between-cluster distance.
ax.set_ylabel(u'组内平方和 (WCSS)')
plt.show()

机器学习之K平均算法聚类

从图中可以看出,曲线在组数为5处出现明显拐点,之后下降趋于平缓,因此最佳组数为5

四、K平均聚类算法训练

In [18]:
# Train the final model with the 5 clusters chosen from the elbow plot.
# init='k-means++' avoids the random-initialisation trap.
kmeans = KMeans(
    n_clusters=5,
    init='k-means++',
    n_init=10,
    max_iter=300,
    random_state=0,
)
y_kmeans = kmeans.fit_predict(X)
print(y_kmeans)                 # cluster label assigned to each customer
print(kmeans.cluster_centers_)  # centroid coordinates after clustering
[4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4
 3 4 3 4 3 4 1 4 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 2 0 2 1 2 0 2 0 2 1 2 0 2 0 2 0 2 0 2 1 2 0 2 0 2
 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0
 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2]
[[ 88.2         17.11428571]
 [ 55.2962963   49.51851852]
 [ 86.53846154  82.12820513]
 [ 25.72727273  79.36363636]
 [ 26.30434783  20.91304348]]

五、可视化集群

In [15]:
# Draw on a dedicated figure so the interactive `%matplotlib notebook` backend
# does not add these points to an earlier figure that is still active.
fig, ax = plt.subplots()

# (label, colour) per cluster. The segment names are the author's reading of
# the clusters and assume the label numbering printed by the training cell;
# NOTE(review): labels may be numbered differently under another scikit-learn
# version or seed -- confirm against the printed centroids.
cluster_styles = [
    (0, 'red'),      # rational spenders
    (1, 'blue'),     # standard
    (2, 'green'),    # target customers
    (3, 'cyan'),     # irrational spenders, handle with care
    (4, 'magenta'),  # price-sensitive customers
]
for label, color in cluster_styles:
    in_cluster = y_kmeans == label  # boolean mask of the rows assigned to this cluster
    ax.scatter(X[in_cluster, 0], X[in_cluster, 1], s=100, c=color,
               label='Cluster {}'.format(label))

# Overlay the fitted centroids with larger markers.
centroids = kmeans.cluster_centers_
ax.scatter(centroids[:, 0], centroids[:, 1], s=300, c='yellow', label='Centroids')

ax.set_title(u'顾客群组')
ax.set_xlabel(u'年收入')
ax.set_ylabel(u'购物指数')
ax.legend()  # show the cluster labels
plt.show()

机器学习之K平均算法聚类

六、项目地址

相关文章: