【问题标题】:Plot correlation matrix using pandas使用 pandas 绘制相关矩阵
【发布时间】:2015-06-08 13:40:16
【问题描述】:

我有一个包含大量特征的数据集,因此分析相关矩阵变得非常困难。我想绘制一个相关矩阵,我们使用 pandas 库中的 dataframe.corr() 函数获得该矩阵。 pandas 库是否提供了任何内置函数来绘制此矩阵?

【问题讨论】:

标签: python pandas matplotlib data-visualization information-visualization


【解决方案1】:

我认为有很多很好的答案,但我将此答案添加给那些需要处理特定列并显示不同情节的人。

import numpy as np
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt

rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(18, 18))
df= df.iloc[: , [3,4,5,6,7,8,9,10,11,12,13,14,17]].copy()
corr = df.corr()
plt.figure(figsize=(11,8))
sns.heatmap(corr, cmap="Greens",annot=True)
plt.show()

【讨论】:

    【解决方案2】:
    corrmatrix = df.corr()
    corrmatrix *= np.tri(*corrmatrix.values.shape, k=-1).T
    corrmatrix = corrmatrix.stack().sort_values(ascending = False).reset_index()
    corrmatrix.columns = ['Признак 1', 'Признак 2', 'Корреляция']
    corrmatrix[(corrmatrix['Корреляция'] >= 0.7) + (corrmatrix['Корреляция'] <= -0.7)]
    drop_columns = corrmatrix[(corrmatrix['Корреляция'] >= 0.82) + (corrmatrix['Корреляция'] <= -0.7)]['Признак 2']
    df.drop(drop_columns, axis=1, inplace=True)
    corrmatrix[(corrmatrix['Корреляция'] >= 0.7) + (corrmatrix['Корреляция'] <= -0.7)]
    

    【讨论】:

    • 您的答案可以通过额外的支持信息得到改进。请edit 添加更多详细信息,例如引用或文档,以便其他人可以确认您的答案是正确的。你可以找到更多关于如何写好答案的信息in the help center
    • 为您的代码添加解释,解释为什么它比公认的答案更好,并确保在代码中使用英语。
    【解决方案3】:

    试试这个函数,它还显示相关矩阵的变量名称:

    def plot_corr(df,size=10):
        """Function plots a graphical correlation matrix for each pair of columns in the dataframe.
    
        Input:
            df: pandas DataFrame
            size: vertical and horizontal size of the plot
        """
    
        corr = df.corr()
        fig, ax = plt.subplots(figsize=(size, size))
        ax.matshow(corr)
        plt.xticks(range(len(corr.columns)), corr.columns)
        plt.yticks(range(len(corr.columns)), corr.columns)
    

    【讨论】:

    • plt.xticks(range(len(corr.columns)), corr.columns, rotation='vertical') 如果你想在 x 轴上垂直排列列名
    • 另一个图形化的东西,但添加 plt.tight_layout() 也可能对长列名有用。
    【解决方案4】:

    如果您的主要目标是可视化相关矩阵,而不是创建绘图本身,那么方便的 pandas styling options 是一个可行的内置解决方案:

    import pandas as pd
    import numpy as np
    
    rs = np.random.RandomState(0)
    df = pd.DataFrame(rs.rand(10, 10))
    corr = df.corr()
    corr.style.background_gradient(cmap='coolwarm')
    # 'RdBu_r', 'BrBG_r', & PuOr_r are other good diverging colormaps
    

    请注意,这需要在支持呈现 HTML 的后端中,例如 JupyterLab Notebook。


    造型

    您可以轻松限制数字精度:

    corr.style.background_gradient(cmap='coolwarm').set_precision(2)
    

    如果您更喜欢没有注释的矩阵,或者完全摆脱数字:

    corr.style.background_gradient(cmap='coolwarm').set_properties(**{'font-size': '0pt'})
    

    样式文档还包含更高级样式的说明,例如如何更改鼠标指针悬停在单元格上的显示。


    时间比较

    在我的测试中,style.background_gradient() 在 10x10 矩阵中比 plt.matshow() 快 4 倍,比 sns.heatmap() 快 120 倍。不幸的是,它的扩展性不如 plt.matshow():对于 100x100 矩阵,两者的时间差不多,而对于 1000x1000 矩阵,plt.matshow() 的速度要快 10 倍。


    保存

    有几种可能的方法来保存程式化的数据框:

    • 通过附加 render() 方法返回 HTML,然后将输出写入文件。
    • 通过附加to_excel() 方法保存为具有条件格式的.xslx 文件。
    • Combine with imgkit to save a bitmap
    • 截屏(就像我在这里所做的那样)。

    标准化整个矩阵的颜色(pandas >= 0.24)

    通过设置axis=None,现在可以基于整个矩阵而不是每列或每行来计算颜色:

    corr.style.background_gradient(cmap='coolwarm', axis=None)
    


    单角热图

    由于很多人都在阅读此答案,我想我会添加一个提示,说明如何仅显示相关矩阵的一个角。我觉得这更容易阅读,因为它删除了多余的信息。

    # Fill diagonal and upper half with NaNs
    mask = np.zeros_like(corr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True
    corr[mask] = np.nan
    (corr
     .style
     .background_gradient(cmap='coolwarm', axis=None, vmin=-1, vmax=1)
     .highlight_null(null_color='#f1f1f1')  # Color NaNs grey
     .set_precision(2))
    

    【讨论】:

    • 如果有办法导出为图片,那就太好了!
    • 谢谢!你肯定需要一个不同的调色板import seaborn as sns corr = df.corr() cm = sns.light_palette("green", as_cmap=True) cm = sns.diverging_palette(220, 20, sep=20, as_cmap=True) corr.style.background_gradient(cmap=cm).set_precision(2)
    • @stallingOne 好点,我不应该在示例中包含负值,我以后可能会更改它。仅供阅读本文的人参考,您不需要使用 seaborn 创建自定义发散 cmap(尽管上面评论中的那个看起来很漂亮),您也可以使用 matplotlib 中的内置发散 cmap,例如corr.style.background_gradient(cmap='coolwarm')。目前无法将 cmap 置于特定值的中心,这对于不同的 cmap 可能是一个好主意。
    • @rovyko 你在 pandas >=0.24.0 吗?
    • 这些图在视觉上很棒,但@Kristada673 的问题非常相关,你将如何导出它们?
    【解决方案5】:

    请检查以下可读代码

    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    plt.figure(figsize=(36, 26))
    heatmap = sns.heatmap(df.corr(), vmin=-1, vmax=1, annot=True)
    heatmap.set_title('Correlation Heatmap', fontdict={'fontsize':12}, pad=12)```
    
      [1]: https://i.stack.imgur.com/I5SeR.png
    

    【讨论】:

      【解决方案6】:

      您可以使用 seaborn 的 heatmap() 来查看黑白不同特征的相关性:

      import matplot.pyplot as plt
      import seaborn as sns
      
      co_matrics=dataframe.corr()
      plot.figure(figsize=(15,20))
      sns.heatmap(co_matrix, square=True, cbar_kws={"shrink": .5})
      

      【讨论】:

        【解决方案7】:

        您可以从matplotlib 使用pyplot.matshow()

        import matplotlib.pyplot as plt
        
        plt.matshow(dataframe.corr())
        plt.show()
        

        编辑:

        在 cmets 中有一个关于如何更改轴刻度标签的请求。这是一个在更大的图形尺寸上绘制的豪华版本,具有与数据框匹配的轴标签,以及用于解释色标的颜色条图例。

        我将介绍如何调整标签的大小和旋转,并且我使用了一个图形比例,使颜色条和主图形的高度相同。


        编辑 2: 由于 df.corr() 方法会忽略非数字列,因此在定义 x 和 y 标签时应使用.select_dtypes(['number']) 以避免标签发生不必要的移位(包含在下面的代码中)。

        f = plt.figure(figsize=(19, 15))
        plt.matshow(df.corr(), fignum=f.number)
        plt.xticks(range(df.select_dtypes(['number']).shape[1]), df.select_dtypes(['number']).columns, fontsize=14, rotation=45)
        plt.yticks(range(df.select_dtypes(['number']).shape[1]), df.select_dtypes(['number']).columns, fontsize=14)
        cb = plt.colorbar()
        cb.ax.tick_params(labelsize=14)
        plt.title('Correlation Matrix', fontsize=16);
        

        【讨论】:

        • 我一定错过了什么:AttributeError: 'module' object has no attribute 'matshow'
        • @TomRussell 你做了import matplotlib.pyplot as plt吗?
        • 你知道如何在绘图上显示实际的列名吗?
        • @Cecilia 我已经通过将 rotation 参数更改为 90 解决了这个问题
        • 如果列名比那些长,x 标签看起来会有点偏离,在我的情况下,它看起来很混乱,因为它们看起来移动了一个刻度。将ha="left" 添加到plt.xticks 调用解决了这个问题,以防万一有人也有它:) 在stackoverflow.com/questions/28615887/… 中描述
        【解决方案8】:

        惊讶地发现没有人提到更强大、更具交互性和更易于使用的替代方案。

        A)你可以使用情节:

        1. 只需两行即可:

        2. 交互性,

        3. 平滑的刻度,

        4. 颜色基于整个数据框而不是单个列,

        5. 轴上的列名和行索引,

        6. 放大,

        7. 平移,

        8. 内置一键保存为PNG格式,

        9. 自动缩放,

        10. 悬停比较,

        11. 显示值的气泡,因此热图看起来仍然不错,您可以看到 任何你想要的值:

        import plotly.express as px
        fig = px.imshow(df.corr())
        fig.show()
        

        B) 你也可以使用散景:

        所有相同的功能都有一点麻烦。但是,如果您不想选择加入 plotly 并且仍然想要所有这些东西,那么仍然值得:

        from bokeh.plotting import figure, show, output_notebook
        from bokeh.models import ColumnDataSource, LinearColorMapper
        from bokeh.transform import transform
        output_notebook()
        colors = ['#d7191c', '#fdae61', '#ffffbf', '#a6d96a', '#1a9641']
        TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
        data = df.corr().stack().rename("value").reset_index()
        p = figure(x_range=list(df.columns), y_range=list(df.index), tools=TOOLS, toolbar_location='below',
                   tooltips=[('Row, Column', '@level_0 x @level_1'), ('value', '@value')], height = 500, width = 500)
        
        p.rect(x="level_1", y="level_0", width=1, height=1,
               source=data,
               fill_color={'field': 'value', 'transform': LinearColorMapper(palette=colors, low=data.value.min(), high=data.value.max())},
               line_color=None)
        color_bar = ColorBar(color_mapper=LinearColorMapper(palette=colors, low=data.value.min(), high=data.value.max()), major_label_text_font_size="7px",
                             ticker=BasicTicker(desired_num_ticks=len(colors)),
                             formatter=PrintfTickFormatter(format="%f"),
                             label_standoff=6, border_line_color=None, location=(0, 0))
        p.add_layout(color_bar, 'right')
        
        show(p)
        

        【讨论】:

          【解决方案9】:

          形成相关矩阵,在我的例子中 zdf 是我需要执行相关矩阵的数据框。

          corrMatrix =zdf.corr()
          corrMatrix.to_csv('sm_zscaled_correlation_matrix.csv');
          html = corrMatrix.style.background_gradient(cmap='RdBu').set_precision(2).render()
          
          # Writing the output to a html file.
          with open('test.html', 'w') as f:
             print('<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-widthinitial-scale=1.0"><title>Document</title></head><style>table{word-break: break-all;}</style><body>' + html+'</body></html>', file=f)
          

          然后我们可以截图。或者将html转换成图片文件。

          【讨论】:

            【解决方案10】:

            与其他方法一起使用pairplot也很好,它可以为所有情况提供散点图-

            import pandas as pd
            import numpy as np
            import seaborn as sns
            rs = np.random.RandomState(0)
            df = pd.DataFrame(rs.rand(10, 10))
            sns.pairplot(df)
            

            【讨论】:

              【解决方案11】:

              为了完整起见,截至 2019 年底,我知道的最简单的解决方案是 seaborn,如果有人使用 Jupyter

              import seaborn as sns
              sns.heatmap(dataframe.corr())
              

              【讨论】:

                【解决方案12】:

                statmodels 图形还提供了一个很好的相关矩阵视图

                import statsmodels.api as sm
                import matplotlib.pyplot as plt
                
                corr = dataframe.corr()
                sm.graphics.plot_corr(corr, xnames=list(corr.columns))
                plt.show()
                

                【讨论】:

                  【解决方案13】:

                  如果您的数据框是df,您可以简单地使用:

                  import matplotlib.pyplot as plt
                  import seaborn as sns
                  
                  plt.figure(figsize=(15, 10))
                  sns.heatmap(df.corr(), annot=True)
                  

                  【讨论】:

                    【解决方案14】:

                    你可以使用 matplotlib 中的 imshow() 方法

                    import pandas as pd
                    import matplotlib.pyplot as plt
                    plt.style.use('ggplot')
                    
                    plt.imshow(X.corr(), cmap=plt.cm.Reds, interpolation='nearest')
                    plt.colorbar()
                    tick_marks = [i for i in range(len(X.columns))]
                    plt.xticks(tick_marks, X.columns, rotation='vertical')
                    plt.yticks(tick_marks, X.columns)
                    plt.show()
                    

                    【讨论】:

                      【解决方案15】:

                      您可以通过从 seaborn 绘制热图或从 pandas 绘制散布矩阵来观察特征之间的关系。

                      散点矩阵:

                      pd.scatter_matrix(dataframe, alpha = 0.3, figsize = (14,8), diagonal = 'kde');
                      

                      如果您还想可视化每个特征的偏度 - 使用 seaborn pairplots。

                      sns.pairplot(dataframe)
                      

                      Sns 热图:

                      import seaborn as sns
                      
                      f, ax = pl.subplots(figsize=(10, 8))
                      corr = dataframe.corr()
                      sns.heatmap(corr, mask=np.zeros_like(corr, dtype=np.bool), cmap=sns.diverging_palette(220, 10, as_cmap=True),
                                  square=True, ax=ax)
                      

                      输出将是特征的相关图。即见下面的例子。

                      杂货和洗涤剂之间的相关性很高。同样:

                      具有高相关性的产品:
                      1. 杂货和洗涤剂。
                      具有中等相关性的产品:
                      1. 牛奶和杂货
                      2. 牛奶和洗涤剂_纸
                      低相关性产品:
                      1. 牛奶和熟食
                      2. 冷冻和新鲜。
                      3. 冷冻和熟食店。

                      从配对图:您可以从配对图或散点矩阵观察相同的关系集。但是从这些我们可以说数据是否是正态分布的。

                      注意:上图是取自数据的同一张图,用于绘制热图。

                      【讨论】:

                      • 我认为应该是 .plt 而不是 .pl(如果这是指 matplotlib)
                      • @ghukill 不一定。他本可以将其称为from matplotlib import pyplot as pl
                      • 如何在相关图中始终设置-1到+1之间的相关边界
                      【解决方案16】:

                      Seaborn 的热图版本:

                      import seaborn as sns
                      corr = dataframe.corr()
                      sns.heatmap(corr, 
                                  xticklabels=corr.columns.values,
                                  yticklabels=corr.columns.values)
                      

                      【讨论】:

                      • Seaborn 热图很漂亮,但在大型矩阵上表现不佳。 matplotlib 的 matshow 方法要快得多。
                      • Seaborn 可以根据列名自动推断刻度标签。
                      • 如果让 seaborn 自动推断 stackoverflow.com/questions/50754471/…,似乎并非总是显示所有刻度标签
                      • 最好还包括将颜色从 -1 标准化为 1,否则颜色将从最低相关(可以是任何地方)到最高相关(1,对角线)。跨度>
                      猜你喜欢
                      • 2018-12-05
                      • 1970-01-01
                      • 1970-01-01
                      • 1970-01-01
                      • 1970-01-01
                      • 2021-11-18
                      • 1970-01-01
                      • 2012-07-19
                      • 2015-08-29
                      相关资源
                      最近更新 更多