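Algorithm features:
①. Gaussian distributions serve as basis functions; ②. several Gaussians are combined convexly; ③. the probability density is estimated by maximum likelihood.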
Algorithm derivation:
The GMM probability density has the following form:
\begin{equation}
p(x) = \sum_{k=1}^{K}\pi_kN(x|\mu_k, \Sigma_k)
\label{eq_1}
\end{equation}
Here $\pi_k$, $\mu_k$ and $\Sigma_k$ denote the weight, mean and covariance matrix of the $k$-th Gaussian, with $\sum\limits_{k=1}^{K}\pi_k = 1$ and $\pi_k \geq 0$ for all $k$.
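As a quick illustration of the density above, here is a minimal NumPy sketch that evaluates $p(x)$ for a small mixture. The helper names `gaussian_pdf` and `gmm_pdf` and the two-component parameters are made up for this example; they are not part of the original post.

```python
import numpy


def gaussian_pdf(x, mu, sigma):
    """Density N(x | mu, sigma) for each row of x, where x has shape (n, d)."""
    x = numpy.atleast_2d(x)
    d = x.shape[1]
    diff = x - mu
    # Solve sigma * y = diff^T rather than forming the explicit inverse.
    solved = numpy.linalg.solve(sigma, diff.T).T
    mahalanobis = numpy.sum(diff * solved, axis=1)
    norm = numpy.sqrt((2 * numpy.pi) ** d * numpy.linalg.det(sigma))
    return numpy.exp(-0.5 * mahalanobis) / norm


def gmm_pdf(x, pis, mus, sigmas):
    """Mixture density p(x) = sum_k pi_k N(x | mu_k, Sigma_k)."""
    return sum(pi * gaussian_pdf(x, mu, sigma)
               for pi, mu, sigma in zip(pis, mus, sigmas))


# Hypothetical two-component mixture in 2D (illustrative values only).
pis = [0.4, 0.6]
mus = [numpy.array([0.0, 0.0]), numpy.array([1.0, 1.0])]
sigmas = [0.03 * numpy.eye(2), 0.1 * numpy.eye(2)]
p = gmm_pdf(numpy.array([[0.0, 0.0], [1.0, 1.0]]), pis, mus, sigmas)
```

Because the weights are nonnegative and sum to one, `p` is a convex combination of the component densities at each query point.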
Let the sample set be $\{x^{(1)}, x^{(2)}, \cdots, x^{(n)}\}$. This article uses the EM (Expectation-Maximization) algorithm to solve for the optimization variables $\{\pi_k, \mu_k, \Sigma_k\}_{k=1\sim K}$.
$step1$:
Randomly initialize $\{\pi_k, \mu_k, \Sigma_k\}_{k=1\sim K}$.
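The text only says the parameters are initialized randomly, so the scheme below is one possible choice and not necessarily the author's: uniform weights, means drawn from randomly chosen samples, and every covariance set to the covariance of the whole dataset. The function name `init_params` and its `seed` argument are assumptions for this sketch.

```python
import numpy


def init_params(X, K, seed=0):
    """One possible random initialization for a K-component GMM (assumed scheme)."""
    rng = numpy.random.default_rng(seed)
    n, d = X.shape
    # Uniform mixing weights satisfy sum(pi) = 1 and pi_k >= 0.
    pis = numpy.full(K, 1.0 / K)
    # Pick K distinct samples as the initial means.
    mus = X[rng.choice(n, size=K, replace=False)].copy()
    # Start every component from the global covariance, slightly regularized
    # so that it stays positive definite.
    base = numpy.cov(X, rowvar=False) + 1e-6 * numpy.eye(d)
    sigmas = numpy.stack([base.copy() for _ in range(K)])
    return pis, mus, sigmas


# Usage on synthetic 2D data (illustrative only).
X_demo = numpy.random.default_rng(1).normal(size=(50, 2))
pis0, mus0, sigmas0 = init_params(X_demo, K=3)
```

EM only finds a local optimum, so in practice different seeds can lead to different fits.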
$step2 \sim E\ step$:
Compute the probability that the $i$-th sample belongs to the $k$-th Gaussian:
\begin{equation}
\gamma_k^{(i)} = \frac{\pi_kN(x^{(i)}|\mu_k, \Sigma_k)}{\sum\limits_{j=1}^{K} \pi_j N(x^{(i)}|\mu_j, \Sigma_j)}
\label{eq_2}
\end{equation}
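The E step can be sketched in NumPy as follows. The function names and the toy parameters are assumptions for illustration; the computation itself follows the formula for $\gamma_k^{(i)}$ directly, without the log-domain tricks a production implementation would probably use to avoid underflow.

```python
import numpy


def gaussian_pdf(X, mu, sigma):
    """Density N(x | mu, sigma) for each row of X, where X has shape (n, d)."""
    d = X.shape[1]
    diff = X - mu
    solved = numpy.linalg.solve(sigma, diff.T).T
    mahalanobis = numpy.sum(diff * solved, axis=1)
    norm = numpy.sqrt((2 * numpy.pi) ** d * numpy.linalg.det(sigma))
    return numpy.exp(-0.5 * mahalanobis) / norm


def e_step(X, pis, mus, sigmas):
    """Return gamma with shape (n, K); gamma[i, k] is the E-step posterior."""
    K = len(pis)
    # Numerators pi_k * N(x_i | mu_k, Sigma_k), one column per component.
    weighted = numpy.stack(
        [pis[k] * gaussian_pdf(X, mus[k], sigmas[k]) for k in range(K)],
        axis=1,
    )
    # Normalize each row by the sum over all components.
    return weighted / weighted.sum(axis=1, keepdims=True)


# Toy example: two unit-covariance components centered at (0, 0) and (3, 3).
pis = numpy.array([0.5, 0.5])
mus = numpy.array([[0.0, 0.0], [3.0, 3.0]])
sigmas = numpy.stack([numpy.eye(2), numpy.eye(2)])
X = numpy.array([[0.0, 0.0], [3.0, 3.0], [1.5, 1.5]])
gamma = e_step(X, pis, mus, sigmas)
```

Each row of `gamma` sums to one, and the midpoint `(1.5, 1.5)` is split evenly between the two components by symmetry.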
$step3 \sim M\ step$:
Compute the effective number of samples assigned to the $k$-th Gaussian:
\begin{equation}
N_k = \sum_{i=1}^{n}\gamma_k^{(i)}
\label{eq_3}
\end{equation}
Update the weight of the $k$-th Gaussian:
\begin{equation}
\pi_k = \frac{N_k}{n}
\label{eq_4}
\end{equation}
Update the mean of the $k$-th Gaussian:
\begin{equation}
\mu_k = \frac{\sum\limits_{i=1}^{n}\gamma_k^{(i)}x^{(i)}}{N_k}
\label{eq_5}
\end{equation}
Update the covariance matrix of the $k$-th Gaussian:
\begin{equation}
\Sigma_k = \frac{\sum\limits_{i=1}^{n}\gamma_k^{(i)}(x^{(i)} - \mu_k)(x^{(i)} - \mu_k)^{\mathrm{T}}}{N_k}
\label{eq_6}
\end{equation}
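The four M-step updates can be written compactly with NumPy. This is a sketch under two assumptions: `gamma` is an `(n, K)` array of E-step probabilities, and the covariance update uses the freshly updated mean, which is the usual reading of the formula. The name `m_step` is chosen for this example.

```python
import numpy


def m_step(X, gamma):
    """Update (pi, mu, Sigma) from data X (n, d) and E-step output gamma (n, K)."""
    n, d = X.shape
    K = gamma.shape[1]
    N_k = gamma.sum(axis=0)                 # effective sample counts, shape (K,)
    pis = N_k / n                           # weights; they sum to 1
    mus = (gamma.T @ X) / N_k[:, None]      # weighted means, shape (K, d)
    sigmas = numpy.empty((K, d, d))
    for k in range(K):
        diff = X - mus[k]
        # Weighted sum of outer products (x - mu)(x - mu)^T, divided by N_k.
        sigmas[k] = (gamma[:, k, None] * diff).T @ diff / N_k[k]
    return pis, mus, sigmas


# With hard 0/1 assignments the updates reduce to per-cluster statistics.
X = numpy.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [10.0, 12.0]])
gamma = numpy.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
pis, mus, sigmas = m_step(X, gamma)
```

Note that the covariance estimate divides by $N_k$, not $N_k - 1$, so it is the maximum-likelihood (biased) estimate.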
$step4$:
Return to $step2$ and repeat until the optimization variables $\{\pi_k, \mu_k, \Sigma_k\}_{k=1\sim K}$ have all converged.
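Putting the four steps together, a minimal EM loop might look like the sketch below. Several details are assumptions rather than the author's choices: the initialization scheme, the stopping rule (largest absolute parameter change below `tol`), the small `1e-8` ridge added to each covariance for numerical safety, and the names `fit_gmm`, `max_iter`, `tol` and `seed`.

```python
import numpy


def gaussian_pdf(X, mu, sigma):
    """Density N(x | mu, sigma) for each row of X, where X has shape (n, d)."""
    d = X.shape[1]
    diff = X - mu
    solved = numpy.linalg.solve(sigma, diff.T).T
    mahalanobis = numpy.sum(diff * solved, axis=1)
    norm = numpy.sqrt((2 * numpy.pi) ** d * numpy.linalg.det(sigma))
    return numpy.exp(-0.5 * mahalanobis) / norm


def fit_gmm(X, K, max_iter=200, tol=1e-6, seed=0):
    """Fit a K-component GMM with EM; returns (pis, mus, sigmas, n_iter)."""
    rng = numpy.random.default_rng(seed)
    n, d = X.shape
    # step1: initialization (assumed scheme: uniform weights, random samples
    # as means, global covariance for every component).
    pis = numpy.full(K, 1.0 / K)
    mus = X[rng.choice(n, size=K, replace=False)].copy()
    sigmas = numpy.stack([numpy.cov(X, rowvar=False) for _ in range(K)])
    for it in range(max_iter):
        # step2 (E step): per-sample, per-component probabilities gamma.
        weighted = numpy.stack(
            [pis[k] * gaussian_pdf(X, mus[k], sigmas[k]) for k in range(K)],
            axis=1,
        )
        gamma = weighted / weighted.sum(axis=1, keepdims=True)
        # step3 (M step): update weights, means and covariances.
        N_k = gamma.sum(axis=0)
        new_pis = N_k / n
        new_mus = (gamma.T @ X) / N_k[:, None]
        new_sigmas = numpy.empty_like(sigmas)
        for k in range(K):
            diff = X - new_mus[k]
            new_sigmas[k] = (gamma[:, k, None] * diff).T @ diff / N_k[k]
            new_sigmas[k] += 1e-8 * numpy.eye(d)  # assumed safety ridge
        # step4: stop once no parameter moved by more than tol.
        change = max(
            numpy.abs(new_pis - pis).max(),
            numpy.abs(new_mus - mus).max(),
            numpy.abs(new_sigmas - sigmas).max(),
        )
        pis, mus, sigmas = new_pis, new_mus, new_sigmas
        if change < tol:
            break
    return pis, mus, sigmas, it + 1


# Usage on two synthetic, well-separated clusters (illustrative data).
rng = numpy.random.default_rng(0)
X = numpy.vstack([
    rng.normal(loc=-3.0, scale=0.5, size=(200, 2)),
    rng.normal(loc=3.0, scale=0.5, size=(200, 2)),
])
pis, mus, sigmas, n_iter = fit_gmm(X, K=2)
```

Monitoring the log-likelihood instead of the raw parameter change is a common alternative stopping rule; the parameter-based check is used here because it matches the wording of $step4$.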
Code implementation:
Part Ⅰ:
The following dataset is used as an example for running the algorithm:

```python
# Generate the clustering dataset

import numpy
from matplotlib import pyplot as plt


numpy.random.seed(3)


def generate_data(clusterNum):
    # First cluster: isotropic Gaussian centered at the origin.
    mu = [0, 0]
    sigma = [[0.03, 0], [0, 0.03]]
    data = numpy.random.multivariate_normal(mu, sigma, (1000, ))

    # Remaining clusters: random mean and a random positive semi-definite covariance.
    for idx in range(clusterNum - 1):
        mu = numpy.random.uniform(-1, 1, (2, ))
        arr = numpy.random.uniform(0, 1, (2, 2))
        sigma = numpy.matmul(arr.T, arr) / 10
        tmpData = numpy.random.multivariate_normal(mu, sigma, (1000, ))
        data = numpy.vstack((data, tmpData))

    return data


def plot_data(data):
    fig = plt.figure(figsize=(5, 3))
    ax1 = plt.subplot()

    # Scatter plot of all samples, saved to disk instead of shown interactively.
    ax1.scatter(data[:, 0], data[:, 1], s=1)
    ax1.set(xlim=[-2, 2], ylim=[-2, 2])
    # plt.show()
    fig.savefig("plot_data.png", dpi=100)
    plt.close()


if __name__ == "__main__":
    X = generate_data(5)
    plot_data(X)
```