【问题标题】:How to create one seaborn plot consisting of kx2 scatterplots with shared legend that has union of all classes如何创建一个包含 kx2 散点图的 seaborn 图,该散点图具有所有类的联合图例
【发布时间】:2020-01-27 15:01:50
【问题描述】:

我有几个散点图,每个散点图都有不同的类。我想将它们全部粘贴在 kx2 网格中,并在包含所有当前类的一侧带有一个图例,例如从单个图中删除图例。

我该怎么做?

这是 2x2 测试的 4 个图

from  matplotlib.lines import  Line2D
import pandas as pd
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
from  matplotlib.lines import  Line2D

df1 = pd.DataFrame({
    "class":["a", "b", "e"],
    "time":[1,2,3],
    "score":[10, 20, 30]
})

df2 = pd.DataFrame({
    "class":["a", "c", "d"],
    "time":[0,5,10],
    "score":[5, 25, 30]
})

df3 = pd.DataFrame({
    "class":["a", "b", "c", "d", "e"],
    "time":[0,5,10,30,50],
    "score":[5, 25, 30, 40, 100]
})

df4 = pd.DataFrame({
    "class":["a", "e"],
    "time":[1,2],
    "score":[10,25]
})

def get_palette():
  pal =  {
      'a': "#4C72B0", 
      'b': "#55A868", 
      'c': "#C44E52", 
      'd': "#8172B2", 
      'e': "#CCB974", 
  }
  return pal

def get_markers():
  mark = {
      'a': Line2D.filled_markers[0], 
      'b': Line2D.filled_markers[5], 
      'c': Line2D.filled_markers[6], 
      'd': Line2D.filled_markers[7],  
      'e': Line2D.filled_markers[8], 
  }
  return mark

def get_scatterplot(source, ds_name):
  scatter = sns.scatterplot(palette=get_palette(), markers=get_markers(), 
                            edgecolor='black', alpha=0.6, x="score", y="time",
                            hue="class", style="class", s=150, 
                            data=source).set_title(ds_name)
  return scatter

scatter_df1 = get_scatterplot(df1, "df1")
plt.show()

scatter_df2 = get_scatterplot(df2, "df2")
plt.show()

scatter_df3 = get_scatterplot(df3, "df3")
plt.show()

scatter_df4 = get_scatterplot(df4, "df4")
plt.show()

这是我根据 Stack 上的其他一些响应尝试做的事情

fig, axs = plt.subplots(ncols=2, nrows=2)
sns.scatterplot(palette=get_palette(), markers=get_markers(), edgecolor='black', alpha=0.6, x="score", y="time", hue="class", style="class", s=150, data=df1, ax=axs[0]).set_title("ds1")
sns.scatterplot(palette=get_palette(), markers=get_markers(), edgecolor='black', alpha=0.6, x="score", y="time", hue="class", style="class", s=150, data=df2, ax=axs[1]).set_title("ds2")
sns.scatterplot(palette=get_palette(), markers=get_markers(), edgecolor='black', alpha=0.6, x="score", y="time", hue="class", style="class", s=150, data=df3, ax=axs[2]).set_title("ds3")
sns.scatterplot(palette=get_palette(), markers=get_markers(), edgecolor='black', alpha=0.6, x="score", y="time", hue="class", style="class", s=150, data=df4, ax=axs[3]).set_title("ds4")

但它出错了,不知道为什么......

AttributeError: 'numpy.ndarray' object has no attribute 'scatter'

【问题讨论】:

  • 您可以使用自定义 figlegend 来获取整个图的单个图例,但您必须处理类重复项(即创建一个类名数组并使用 np.unique) .这足以满足您的目的吗?
  • 我对类重复有点困惑......你能回答一个答案,以便澄清事情并且我可以接受吗?我对此真的很陌生,所以这个提示没有多大帮助:(谢谢!
  • @WilliamMiller 我也添加了我的尝试,但我不知道它是接近正确还是完全错过
  • 您在编辑中遇到的错误是由于您从 axs 访问轴的方式。 fig, axs = plt.subplots(nrows, ncols) 返回一个由axes 实例组成的数组,其形状为(nrows, ncols),因此您需要执行axs[i, j] 来检索单个axes 实例(请参阅this answer

标签: python pandas matplotlib seaborn


【解决方案1】:

您可以使用matplotlib.pyplot.figlegend 为图形制作单个图例。在不传递任何参数的情况下,这将从“每个轴上的现有艺术家。”创建一个图例。如果您想自定义它,您可以直接提供图例句柄和标签。

由于您明确指定每个“类”的颜色,因此很容易编写自定义图例:

pal = get_palette()
handles = [Line2D([0], [0], color=c) for l, c in pal.items()]
labels = [l for l in pal]
plt.figlegend(handles=handles, labels=labels, loc='best')
plt.show()

应该做的伎俩。使用plt.subplots(nrows=2, ncols=2) 和问题代码,这将为您提供一个看起来像这样的图例

请注意,这适用于任何配置中的任意数量的类和任意数量的子图,前提是类及其对应的颜色都在pal 中定义,否则需要采用更高级的方法。

【讨论】:

    【解决方案2】:

    要解决上一个错误,您需要使用行/列索引以矩阵样式传递ax,因为您使用nrowncol 指定了子图布局:

    ...
    
    fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(12,8))
    
    sns.scatterplot(..., ax=axs[0,0]).set_title("ds1")
    sns.scatterplot(..., ax=axs[0,1]).set_title("ds2")
    sns.scatterplot(..., ax=axs[1,0]).set_title("ds3")
    sns.scatterplot(..., ax=axs[1,1]).set_title("ds4")
    
    plt.tight_layout()
    plt.show()
    


    要解决共享图例甚至共享轴的所需结果,请考虑将所有数据帧编译为一个并使用seaborn.FacetGrid 运行绘图。需要立即进行的一项更改是标记函数,它需要一个列表而不是 dict。 ...

    def get_markers_list():
      mark = [
          Line2D.filled_markers[0], 
          Line2D.filled_markers[5], 
          Line2D.filled_markers[6], 
          Line2D.filled_markers[7],  
          Line2D.filled_markers[8], 
      ]
      return mark
    
    # COMPILE ALL DFs INTO ONE
    master_df = pd.concat([df1.assign(grp="ds1"),
                           df2.assign(grp="ds2"),
                           df3.assign(grp="ds3"),
                           df4.assign(grp="ds4")])
    
    # RUN FACET GRID
    g = sns.FacetGrid(master_df, col="grp", hue="class", col_wrap=2, 
                      aspect=1.5, palette=get_palette(),
                      hue_order=list('abcde'),
                      hue_kws=dict(marker=get_markers_list()))
    
    g = (g.map(sns.scatterplot, "score", "time", 
               edgecolor='black', alpha=0.6, s=150)
          .add_legend())
    
    plt.show()
    

    【讨论】:

      猜你喜欢
      • 2020-12-04
      • 1970-01-01
      • 2019-04-15
      • 2016-01-30
      • 2018-04-04
      • 1970-01-01
      • 2017-05-09
      • 2020-11-07
      • 2020-08-10
      相关资源
      最近更新 更多