【问题标题】:How to display legends in scatter plot in order to differentiate between the classes如何在散点图中显示图例以区分类
【发布时间】:2021-06-14 18:35:54
【问题描述】:

我正在处理来自 sklearn 的 iris 数据集。您可能知道 iris 数据集有 3 个类 ['setosa'、'versicolor'、'virginica']。我为这个数据集做了一个散点图。详情如下

from sklearn.datasets import load_iris
iris=load_iris()
Y_train=iris.target
X_train=iris.data
class_labels=iris.target_names
plt.scatter(X_train[:,0], X_train[:,1], c=Y_train)
plt.xlabel('attr1')
plt.ylabel('attr2')
plt.show()

我有散点图,您可以在其中看到黄色、绿色和紫色的点。我想知道哪个颜色点属于哪个类('setosa'、'versicolor'、'virginica')。我想显示图例,以便我知道哪种颜色代表哪个类

【问题讨论】:

    标签: python matplotlib scikit-learn scatter-plot iris-dataset


    【解决方案1】:

    在这种情况下,您可以通过循环遍历标签并使用与散点图相同的colormapnorm 来创建custom legend。默认情况下,使用'viridis' 颜色映射,以及将最小颜色值映射为 0 并将最大值映射为 1 的规范。

    import matplotlib.pyplot as plt
    from sklearn.datasets import load_iris
    
    iris = load_iris()
    Y_train = iris.target
    X_train = iris.data
    class_labels = iris.target_names
    cmap = plt.get_cmap('viridis')
    norm = plt.Normalize(Y_train.min(), Y_train.max())
    plt.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, cmap='viridis', norm=norm)
    handles = [plt.Line2D([0, 0], [0, 0], color=cmap(norm(i)), marker='o', linestyle='', label=label)
               for i, label in enumerate(class_labels)]
    plt.legend(handles=handles, title='Species')
    plt.show()
    

    您也可以使用 seaborn,尽管目前设置图例标签并不简单。

    import seaborn as sns
    
    sns.set()
    ax = sns.scatterplot(x=X_train[:, 0], y=X_train[:, 1], hue=Y_train, palette='viridis')
    ax.legend(ax.legend_.legendHandles, class_labels, title='Species')
    

    【讨论】:

      猜你喜欢
      • 2016-08-31
      • 2016-12-29
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-07-27
      • 1970-01-01
      相关资源
      最近更新 更多