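  • Algorithm characteristics:
    ①. Gaussian distributions serve as the basis functions; ②. several Gaussians are combined convexly; ③. the probability density is estimated by maximum likelihood.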
  • Algorithm derivation:
    The GMM probability density has the form:
    \begin{equation}
    p(x) = \sum_{k=1}^{K}\pi_kN(x|\mu_k, \Sigma_k)
    \label{eq_1}
    \end{equation}
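    To make the mixture density concrete, here is a minimal NumPy sketch that evaluates it for a batch of points. The helper names `gaussian_pdf` and `gmm_pdf` are hypothetical, and the Gaussian density is written out by hand with `np.linalg.solve` and `np.linalg.slogdet` instead of calling a statistics library.

```python
import numpy as np


def gaussian_pdf(X, mu, sigma):
    """Multivariate normal density N(x | mu, sigma) for every row x of X (shape (n, d))."""
    X = np.atleast_2d(X)
    d = X.shape[1]
    diff = X - mu                                   # (n, d)
    # Solve sigma^{-1} (x - mu) instead of forming the inverse explicitly
    sol = np.linalg.solve(sigma, diff.T).T          # (n, d)
    mahal = np.sum(diff * sol, axis=1)              # squared Mahalanobis distances
    _, logdet = np.linalg.slogdet(sigma)            # log |sigma|, numerically stable
    log_norm = -0.5 * (d * np.log(2.0 * np.pi) + logdet)
    return np.exp(log_norm - 0.5 * mahal)


def gmm_pdf(X, pis, mus, sigmas):
    """p(x) = sum_k pi_k N(x | mu_k, Sigma_k), evaluated row-wise on X."""
    return sum(pi_k * gaussian_pdf(X, mu_k, sigma_k)
               for pi_k, mu_k, sigma_k in zip(pis, mus, sigmas))


if __name__ == "__main__":
    pis = np.array([0.3, 0.7])                      # weights: non-negative, sum to 1
    mus = np.array([[0.0, 0.0], [1.0, 1.0]])
    sigmas = np.array([np.eye(2), 0.5 * np.eye(2)])
    print(gmm_pdf(np.array([[0.0, 0.0], [1.0, 1.0]]), pis, mus, sigmas))
```

    Because the weights form a convex combination and each Gaussian integrates to 1, the mixture also integrates to 1; a coarse grid sum over a large enough box is a quick sanity check of that property.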
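    where $\pi_k$, $\mu_k$, and $\Sigma_k$ denote the weight, mean, and covariance matrix of the $k$-th Gaussian, respectively, with $\sum\limits_{k=1}^{K}\pi_k = 1$ and $\pi_k \geq 0$ for all $k$.
    Given the sample set $\{x^{(1)}, x^{(2)}, \cdots, x^{(n)}\}$, this article uses the EM (Expectation-Maximization) algorithm to estimate the parameters $\{\pi_k, \mu_k, \Sigma_k\}_{k=1\sim K}$ by maximizing the log-likelihood $\sum_{i=1}^{n}\ln p(x^{(i)})$.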
    $step1$:
    Randomly initialize $\{\pi_k, \mu_k, \Sigma_k\}_{k=1\sim K}$.
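    "Random initialization" leaves room for several schemes. The sketch below shows one common choice, purely as an assumption: uniform weights, $K$ distinct samples drawn at random as the means, and every covariance set to the average per-feature variance times the identity. The function name `init_params` is hypothetical.

```python
import numpy as np


def init_params(X, K, seed=0):
    """Hypothetical step-1 initializer (assumed scheme, not the only valid one)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    pis = np.full(K, 1.0 / K)                               # uniform weights, sum to 1
    mus = X[rng.choice(n, size=K, replace=False)].astype(float)  # K distinct samples as means
    var = X.var(axis=0).mean()                              # average per-feature variance
    sigmas = np.tile(var * np.eye(d), (K, 1, 1))            # isotropic, positive definite
    return pis, mus, sigmas


if __name__ == "__main__":
    X = np.random.default_rng(1).normal(size=(100, 2))
    pis, mus, sigmas = init_params(X, K=3)
    print(pis.shape, mus.shape, sigmas.shape)  # (3,) (3, 2) (3, 2, 2)
```

    Starting from broad isotropic covariances keeps every $\Sigma_k$ invertible in the first E step, at the cost of ignoring correlations until the first M step.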
    $step2 \sim E\ step$:
    Compute the probability (the "responsibility") that the $i$-th sample belongs to the $k$-th Gaussian:
    \begin{equation}
    \gamma_k^{(i)} = \frac{\pi_kN(x^{(i)}|\mu_k, \Sigma_k)}{\sum\limits_{j=1}^{K} \pi_j N(x^{(i)}|\mu_j, \Sigma_j)}
    \label{eq_2}
    \end{equation}
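    The E step can be sketched in NumPy as follows. Evaluating the ratio directly underflows to $0/0$ for points far from every component, so this sketch (an implementation choice, not part of the derivation) works with logarithms and normalizes with the log-sum-exp trick. `log_gaussian_pdf` and `e_step` are hypothetical names.

```python
import numpy as np


def log_gaussian_pdf(X, mu, sigma):
    """log N(x | mu, sigma) for every row x of X."""
    d = X.shape[1]
    diff = X - mu
    sol = np.linalg.solve(sigma, diff.T).T
    mahal = np.sum(diff * sol, axis=1)
    _, logdet = np.linalg.slogdet(sigma)
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + mahal)


def e_step(X, pis, mus, sigmas):
    """gamma[i, k] = pi_k N(x_i|mu_k,Sigma_k) / sum_j pi_j N(x_i|mu_j,Sigma_j)."""
    K = len(pis)
    # Log of each numerator term, shape (n, K)
    log_num = np.column_stack([np.log(pis[k]) + log_gaussian_pdf(X, mus[k], sigmas[k])
                               for k in range(K)])
    # Log of the denominator: subtract the row maximum before exponentiating
    row_max = log_num.max(axis=1, keepdims=True)
    log_den = row_max + np.log(np.exp(log_num - row_max).sum(axis=1, keepdims=True))
    return np.exp(log_num - log_den)  # each row sums to 1
```

    For a point exactly halfway between two equally weighted components with identical covariances, both responsibilities come out as 0.5; for a point very far away, the closer component still receives a responsibility near 1 instead of `nan`.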
    $step3 \sim M\ step$:
    Compute the effective (soft) number of samples assigned to the $k$-th Gaussian:
    \begin{equation}
    N_k = \sum_{i=1}^{n}\gamma_k^{(i)}
    \label{eq_3}
    \end{equation}
    Update the weight of the $k$-th Gaussian:
    \begin{equation}
    \pi_k = \frac{N_k}{n}
    \label{eq_4}
    \end{equation}
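    Since each row of responsibilities sums to 1, the $N_k$ sum to the sample count $n$, which is why dividing by $n$ yields weights that again sum to 1. A small sketch with made-up responsibilities for $n = 3$ samples and $K = 2$ components (`update_weights` is a hypothetical name):

```python
import numpy as np


def update_weights(gamma):
    """N_k = sum_i gamma[i, k]; pi_k = N_k / n."""
    n = gamma.shape[0]
    N_k = gamma.sum(axis=0)   # effective number of samples per component, shape (K,)
    pis = N_k / n             # new mixture weights
    return N_k, pis


if __name__ == "__main__":
    # Toy responsibilities (illustrative numbers only); every row sums to 1
    gamma = np.array([[0.9, 0.1],
                      [0.8, 0.2],
                      [0.3, 0.7]])
    N_k, pis = update_weights(gamma)
    print(N_k, pis)  # N_k is about [2, 1], pis about [2/3, 1/3]
```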
    Update the mean of the $k$-th Gaussian:
    \begin{equation}
    \mu_k = \frac{\sum\limits_{i=1}^{n}\gamma_k^{(i)}x^{(i)}}{N_k}
    \label{eq_5}
    \end{equation}
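    The mean update is a responsibility-weighted average of the samples, and all $K$ means can be computed with one matrix product. A sketch (`update_means` is a hypothetical name):

```python
import numpy as np


def update_means(X, gamma):
    """mu_k = sum_i gamma[i, k] x_i / N_k, computed for all k at once."""
    N_k = gamma.sum(axis=0)                 # (K,)
    return (gamma.T @ X) / N_k[:, None]     # (K, n) @ (n, d) -> (K, d)


if __name__ == "__main__":
    X = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]])
    # Hard (0/1) responsibilities reduce the update to ordinary per-cluster averages
    gamma = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
    print(update_means(X, gamma))  # rows (1, 0) and (11, 10)
```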
    Update the covariance matrix of the $k$-th Gaussian, using the newly updated mean $\mu_k$:
    \begin{equation}
    \Sigma_k = \frac{\sum\limits_{i=1}^{n}\gamma_k^{(i)}(x^{(i)} - \mu_k)(x^{(i)} - \mu_k)^{\mathrm{T}}}{N_k}
    \label{eq_6}
    \end{equation}
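    A sketch of the covariance update, a responsibility-weighted sum of outer products. The optional `reg_covar` argument is an added assumption (0 reproduces the formula exactly); adding a small multiple of the identity keeps $\Sigma_k$ invertible when a component collapses onto very few points, similar in spirit to the `reg_covar` parameter of scikit-learn's `GaussianMixture`. `update_covariances` is a hypothetical name.

```python
import numpy as np


def update_covariances(X, gamma, mus, reg_covar=0.0):
    """Sigma_k = sum_i gamma[i, k] (x_i - mu_k)(x_i - mu_k)^T / N_k (+ reg_covar * I).

    mus should already hold the means from the mean update.
    """
    n, d = X.shape
    K = gamma.shape[1]
    N_k = gamma.sum(axis=0)
    sigmas = np.empty((K, d, d))
    for k in range(K):
        diff = X - mus[k]                                         # (n, d)
        # (d, n) @ (n, d): weighted sum of outer products diff_i diff_i^T
        sigmas[k] = (gamma[:, k, None] * diff).T @ diff / N_k[k]
        sigmas[k] += reg_covar * np.eye(d)                        # optional regularization
    return sigmas
```

    With a single component whose responsibilities are all 1, this reduces to the biased sample covariance, i.e. `np.cov(X, rowvar=False, bias=True)`.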
    $step4$:
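    Return to $step2$ and repeat until the parameters $\{\pi_k, \mu_k, \Sigma_k\}_{k=1\sim K}$ converge (in practice, convergence is commonly checked by whether the log-likelihood still increases by more than a small tolerance).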
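    Putting steps 1 to 4 together, a compact, self-contained EM loop might look like the sketch below. Several choices here are assumptions rather than part of the derivation: `fit_gmm` and its parameters are hypothetical, the means are initialized by farthest-point selection (one random sample, then repeatedly the sample farthest from those already chosen), covariances start isotropic, `reg_covar` is added to each $\Sigma_k$, and convergence is judged by the change in log-likelihood. There is no safeguard against a component whose $N_k$ drops to 0.

```python
import numpy as np


def log_gaussian(X, mu, sigma):
    """log N(x | mu, sigma) for every row x of X."""
    d = X.shape[1]
    diff = X - mu
    mahal = np.sum(diff * np.linalg.solve(sigma, diff.T).T, axis=1)
    _, logdet = np.linalg.slogdet(sigma)
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + mahal)


def fit_gmm(X, K, max_iter=200, tol=1e-6, reg_covar=1e-6, seed=0):
    rng = np.random.default_rng(seed)
    n, d = X.shape
    # Step 1 (assumed scheme): uniform weights, farthest-point means, isotropic covariances
    pis = np.full(K, 1.0 / K)
    idx = [int(rng.integers(n))]
    for _ in range(K - 1):
        dist = ((X[:, None, :] - X[idx][None, :, :]) ** 2).sum(axis=2).min(axis=1)
        idx.append(int(np.argmax(dist)))
    mus = X[idx].astype(float)
    sigmas = np.tile(X.var(axis=0).mean() * np.eye(d), (K, 1, 1))
    prev_ll = -np.inf
    for _ in range(max_iter):
        # Step 2 (E step): responsibilities gamma[i, k], computed in log space
        log_num = np.column_stack([np.log(pis[k]) + log_gaussian(X, mus[k], sigmas[k])
                                   for k in range(K)])
        row_max = log_num.max(axis=1, keepdims=True)
        log_den = row_max + np.log(np.exp(log_num - row_max).sum(axis=1, keepdims=True))
        gamma = np.exp(log_num - log_den)
        ll = float(log_den.sum())  # sum_i ln p(x_i) under the parameters used in this E step
        # Step 3 (M step): N_k, pi_k = N_k / n, mu_k, Sigma_k
        N_k = gamma.sum(axis=0)
        pis = N_k / n
        mus = (gamma.T @ X) / N_k[:, None]
        for k in range(K):
            diff = X - mus[k]
            sigmas[k] = (gamma[:, k, None] * diff).T @ diff / N_k[k] + reg_covar * np.eye(d)
        # Step 4: stop once the log-likelihood no longer improves by more than tol
        if ll - prev_ll < tol:
            break
        prev_ll = ll
    return pis, mus, sigmas, ll


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = np.vstack([rng.normal([0.0, 0.0], 0.1, size=(300, 2)),
                   rng.normal([3.0, 3.0], 0.1, size=(300, 2))])
    pis, mus, sigmas, ll = fit_gmm(X, K=2)
    print(np.round(pis, 3), np.round(mus, 2))
```

    EM only finds a local maximum of the likelihood, so on overlapping clusters (such as the 5-cluster dataset generated in Part Ⅰ) the result can depend on the initialization; rerunning with several `seed` values and keeping the fit with the highest log-likelihood is a common remedy.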
  • Code implementation:
    Part Ⅰ:
    The following dataset is used to demonstrate the algorithm:
```python
# Generate a clustering dataset

import numpy
from matplotlib import pyplot as plt


numpy.random.seed(3)


def generate_data(clusterNum):
    # First cluster: 1000 points around the origin with isotropic covariance 0.03 * I
    mu = [0, 0]
    sigma = [[0.03, 0], [0, 0.03]]
    data = numpy.random.multivariate_normal(mu, sigma, (1000, ))

    # Remaining clusters: random center in [-1, 1]^2, random PSD covariance A^T A / 10
    for idx in range(clusterNum - 1):
        mu = numpy.random.uniform(-1, 1, (2, ))
        arr = numpy.random.uniform(0, 1, (2, 2))
        sigma = numpy.matmul(arr.T, arr) / 10
        tmpData = numpy.random.multivariate_normal(mu, sigma, (1000, ))
        data = numpy.vstack((data, tmpData))

    return data


def plot_data(data):
    fig = plt.figure(figsize=(5, 3))
    ax1 = plt.subplot()

    ax1.scatter(data[:, 0], data[:, 1], s=1)
    ax1.set(xlim=[-2, 2], ylim=[-2, 2])
    # plt.show()
    fig.savefig("plot_data.png", dpi=100)
    plt.close()


if __name__ == "__main__":
    X = generate_data(5)  # 5 clusters x 1000 points, so X.shape == (5000, 2)
    plot_data(X)
```