机器学习算法与Python实践(13) - 均值漂移聚类 Mean-Shift Clustering

其实相信很多人多少都已经接触过这种聚类的方法,这篇文章也是参考别人的做的总结,也算是加深自己印象的一个笔记。

一、算法概述

Mean Shift算法,又称为均值漂移算法,Mean Shift的概念最早是由Fukunage在1975年提出的,在后来由Yizong Cheng对其进行扩充,主要提出了两点的改进:

  • 定义了核函数;
  • 增加了权重系数。

核函数的定义使得偏移值对偏移向量的贡献随之样本与被偏移点的距离的不同而不同。权重系数使得不同样本的权重不同。Mean Shift算法在聚类,图像平滑、分割以及视频跟踪等方面有广泛的应用。

二、算法核心原理

2.1 核函数

在Mean Shift算法中引入核函数的目的是使得随着样本与被偏移点的距离不同,其偏移量对均值偏移向量的贡献也不同。核函数是机器学习中常用的一种方式。核函数的定义如下所示:

XX表示一个dd维的欧式空间,xx是该空间中的一个点x=x1,x2,x3,xdx={x_1,x_2,x_3⋯,x_d}其中,xx的模x2=xxT‖x‖^2=xx^TRR实数域,如果一个函数K:XRK:X→R存在一个剖面函数k:[0,]Rk:[0,∞]→R,即
K(x)=k(x2)K(x)=k(‖x‖^2)
并且满足:
(1) kk是非负的
(2) kk是非增的
(3) kk是分段连续的
那么,函数K(x)K(x)就称为核函数。

常用的核函数有高斯核函数。高斯核函数如下:

N(x)=12πhex22h2N(x)=\frac{1}{\sqrt{2\pi}h}e^{-\frac{x^2}{2h^2}}

其中,hh称为带宽(bandwidth),不同带宽的核函数如下图所示:

机器学习算法与Python实践(13) - 均值漂移聚类(Mean-Shift Clustering)

上图的画图脚本如下所示:

import matplotlib.pyplot as plt
import math

def cal_Gaussian(x, h=1):
    molecule = x * x
    denominator = 2 * h * h
    left = 1 / (math.sqrt(2 * math.pi) * h)
    return left * math.exp(-molecule / denominator)

x = []

for i in xrange(-40,40):
    x.append(i * 0.5);

score_1 = []
score_2 = []
score_3 = []
score_4 = []

for i in x:
    score_1.append(cal_Gaussian(i,1))
    score_2.append(cal_Gaussian(i,2))
    score_3.append(cal_Gaussian(i,3))
    score_4.append(cal_Gaussian(i,4))

plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")

plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()

2.2 Mean Shift 算法核心思想

2.21 基本原理

对于Mean Shift算法,是一个迭代的步骤,即先算出当前点的偏移均值,将该点移动到此偏移均值,然后以此为新的起始点,继续移动,直到满足最终的条件。此过程可由下图的过程进行说明(图片来自参考文献3):

  • 步骤1:在指定的区域内计算偏移均值(如下图的黄色的圈)

机器学习算法与Python实践(13) - 均值漂移聚类(Mean-Shift Clustering)

  • 步骤2:移动该点到偏移均值点处
    机器学习算法与Python实践(13) - 均值漂移聚类(Mean-Shift Clustering)

  • 步骤3: 重复上述的过程(计算新的偏移均值,移动)
    机器学习算法与Python实践(13) - 均值漂移聚类(Mean-Shift Clustering)


机器学习算法与Python实践(13) - 均值漂移聚类(Mean-Shift Clustering)


机器学习算法与Python实践(13) - 均值漂移聚类(Mean-Shift Clustering)


机器学习算法与Python实践(13) - 均值漂移聚类(Mean-Shift Clustering)

  • 步骤4:满足了最终的条件,即退出
    机器学习算法与Python实践(13) - 均值漂移聚类(Mean-Shift Clustering)

从上述过程可以看出,在Mean Shift算法中,最关键的就是计算每个点的偏移均值,然后根据新计算的偏移均值更新点的位置。

2.22 基本Mean Shift向量形式

对于给定的dd维空间RdR^d中的nn个样本点xix_ii=1,2,...,ni=1,2,...,n,则对于xx点,其mean shift向量的基本形式为:

Mh(x)=1kxiSh(xix)M_h(x)=\frac{1}{k}\sum_{x_i\in{S_h}}(x_i-x)

其中ShS_h指的是一个半径为hh的高维球区域,如上图中的蓝色圆形区域。ShS_h定义为:

Sh(x)=(y(yx)(yx)Th2)S_h(x)=(y∣(y−x)(y−x)^T⩽h^2)

这样的一种基本的Mean Shift形式存在一个问题:在ShS_h的区域内,每一个点对x的贡献是一样的。而实际上,这种贡献与x到每一个点之间的距离是相关的。同时,对于每一个样本,其重要程度也是不一样的。

官网上给的例子:

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs

# #############################################################################
# Generate sample data  造用于聚类的数据
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# #############################################################################
# Compute clustering with MeanShift

# The following bandwidth can be automatically detected using
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)  # 训练模型
labels = ms.labels_  # 所有点的的labels
cluster_centers = ms.cluster_centers_  # 聚类得到的中心点

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
import matplotlib.pyplot as plt
from itertools import cycle

plt.figure(1)
plt.clf()

colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    print(cluster_center)
    plt.plot(X[my_members, 0], X[my_members, 1], col + '.')
    plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
             markeredgecolor='k', markersize=14)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()

未完待续 …

相关文章: