【问题标题】:How to apply KMeans to get the centroid using dataframe with multiple features如何使用具有多个特征的数据框应用 KMeans 来获取质心
【发布时间】:2021-02-24 13:52:38
【问题描述】:

我正在关注这个详细的 KMeans 教程:https://github.com/python-engineer/MLfromscratch/blob/master/mlfromscratch/kmeans.py,它使用具有 2 个特征的数据集。

但我有一个包含 5 个特征(列)的数据框,因此我没有在教程中使用 def euclidean_distance(x1, x2): 函数,而是按如下方式计算欧几里得距离。

def euclidean_distance(df):
    n = df.shape[1]
    distance_matrix = np.zeros((n,n))
    for i in range(n):
        for j in range(n):
            distance_matrix[i,j] = np.sqrt(np.sum((df.iloc[:,i] - df.iloc[:,j])**2))
    return distance_matrix

接下来我想实现教程中计算质心的部分,如下所示;

def _closest_centroid(self, sample, centroids):
    distances = [euclidean_distance(sample, point) for point in centroids]

由于我的 def euclidean_distance(df): 函数只需要 1 个参数 df,我怎样才能最好地实现它以获得质心?

我的示例数据集,df如下:

col1,col2,col3,col4,col5
0.54,0.68,0.46,0.98,-2.14
0.52,0.44,0.19,0.29,30.44
1.27,1.15,1.32,0.60,-161.63
0.88,0.79,0.63,0.58,-49.52
1.39,1.15,1.32,0.41,-188.52
0.86,0.80,0.65,0.65,-45.27

[添加:plot() 函数]

您包含的绘图函数给出了一个错误TypeError: object of type 'itertools.combinations' has no len(),我通过将len(combinations) 更改为len(list(combinations)) 解决了这个问题。但是输出 不是散点图。知道我需要在这里解决什么吗?

【问题讨论】:

  • 教程中的欧几里得距离函数是为数组定义的,因此空间的维度无关紧要。这意味着您不需要编写自己的函数。
  • 本教程中的函数适用于具有任意数量特征的两个数组(我在之前的评论中所说的维度)。它从数据集的形状中推断出特征的数量。运行教程中的代码时,您究竟在哪里遇到错误?
  • 第 81 行,在 _closest_centroid distances = [euclidean_distance(sample, point) for point in centroids] TypeError: euclidean_distance() 只需要 1 个参数(给定 2 个)
  • repo 中的函数不会抛出该错误,因为它确实需要 2 个参数。你确定你在用那个吗?
  • 当我使用教程函数时,line 17: y_pred = k.predict(X) 上的 kmeans_test.py 抛出 ' ValueError:无法识别的标记样式 [13.15717]' 指向 kmeans.py 文件上的 第 35 行93 行。如前所述,我更改了 make_blobs() 函数来满足我的 31 行和 5 列(功能)的数据框,如下所示。否则,教程代码无需任何修改即可正常运行。 data = pd.read_csv('df.csv')X = np.array(data)print(X.shape)clusters = 5k = KMeans(K=clusters, max_iters=150, plot_steps=True)y_pred = k.predict(X)k.plot()

标签: python dataframe k-means euclidean-distance


【解决方案1】:

读取数据并对其进行聚类不应引发任何错误,即使您增加数据集中的特征数量也是如此。 事实上,当你重新定义 euclidean_distance 函数时,你只会在这部分代码中得到错误。

此答案解决了您得到的绘图功能的实际错误。

   def plot(self):
      fig, ax = plt.subplots(figsize=(12, 8))

       for i, index in enumerate(self.clusters):
           point = self.X[index].T
           ax.scatter(*point)

获取给定集群中的所有点并尝试制作散点图。

ax.scatter(*point) 中的星号表示该点已解包。

这里的隐含假设(这就是为什么这可能很难发现)是point 应该是二维的。然后,各个部分被解释为要绘制的 x,y 值。

但是由于您有 5 个特征,所以点是 5 维的。

看着the docs of ax.scatter

matplotlib.axes.Axes.scatter
Axes.scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None,
verts=<deprecated parameter>, edgecolors=None, *, plotnonfinite=False,
data=None, **kwargs)

所以,ax.scatter 采用的前几个参数(除了 self)是:

x 
y
s (i.e. the markersize)
c (i.e. the color)
marker (i.e. the markerstyle)

前四个,即 x,y, s anc c 允许浮点数,但您的数据集是 5 维的,因此第五个特征被解释为标记,它需要一个 MarkerStyle。因为它得到一个浮点数,所以它会抛出错误。

做什么:

一次只查看 2 或 3 个维度,或使用降维(例如主成分分析)将数据投影到较低维度的空间。

对于第一个选项,您可以在 KMeans 类中重新定义 plot 方法:

def plot(self):
    

    import itertools
    combinations = itertools.combinations(range(self.K), 2) # generate all combinations of features
    
    fig, axes = plt.subplots(figsize=(12, 8), nrows=len(combinations), ncols=1) # initialise one subplot for each feature combination

    for (x,y), ax in zip(combinations, axes.ravel()): # loop through combinations and subpltos
        
        
        for i, index in enumerate(self.clusters):
            point = self.X[index].T
            
            # only get the coordinates for this combination:
            px, py = point[x], point[y]
            ax.scatter(px, py)

        for point in self.centroids:
            
            # only get the coordinates for this combination:
            px, py = point[x], point[y]
            
            ax.scatter(px, py, marker="x", color='black', linewidth=2)

        ax.set_title('feature {} vs feature {}'.format(x,y))
    plt.show()

【讨论】:

  • 非常感谢@warped。我现在可以使用原始教程使用 2 维进行聚类(但是,上面的 plot(self) 函数会执行,但只绘制黑色矩形而不是聚类)。我还尝试使用 PCA 将 5 个特征减少到 2 个维度,并且能够使集群正常工作。
  • @Gee 你可以使用 ax.plot 代替 ax.scatter
  • @Gee 您可以将其编辑到您的问题中或提出一个新问题吗?在评论中阅读这有点乏味
  • 知道如何获得您上面建议的 plot() 方法来绘制集群吗?我已经尝试了几次,但它没有绘制集群。
猜你喜欢
  • 2019-11-20
  • 2021-10-01
  • 1970-01-01
  • 2020-12-13
  • 2021-06-13
  • 2023-01-05
  • 2014-10-29
  • 2020-12-07
  • 1970-01-01
相关资源
最近更新 更多