【问题标题】:matplotlib: change axis ticks of ndim histogram plotted with seaborn.heatmapmatplotlib:更改用 seaborn.heatmap 绘制的 ndim 直方图的轴刻度
【发布时间】:2019-09-02 22:31:51
【问题描述】:

动机:

我正在尝试可视化包含许多 n 维向量的数据集(假设我有 10k 个向量,n=300 维)。我想做的是为每个 n 维计算一个直方图,并将其绘制为 bins*n 热图中的一条线。

到目前为止,我已经得到了这个:

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns

# sample data:
vectors = np.random.randn(10000, 300) + np.random.randn(300)

def ndhist(vectors, bins=500):
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]
    for dim in range(dims):
        h, bins = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)
    fig = plt.figure(figsize=(16, 9))
    sns.heatmap(hists)
    axes = fig.gca()
    axes.set(ylabel='dimensions', xlabel='values')
    print(dims)
    print(limits)

ndhist(vectors)

这会生成以下输出:

300
(-6.538069472429366, 6.52159540162285)

问题/疑问:

如何更改坐标轴刻度?

  • 对于 y 轴,我想简单地把它改回 matplotlib 的默认值,所以它会选择像 0, 50, 100, ..., 250 这样的漂亮刻度(299300 的奖励积分)
  • 对于 x 轴,我想将显示的 bin 索引转换为 bin(左)边界,然后,如上所述,我想将其改回 matplotlib 的一些“不错”刻度的默认选择,例如-5, -2.5, 0, 2.5, 5(积分还包括实际限制-6.538, 6.522

自己的解决方案尝试:

我已经尝试了很多类似以下的方法:

def ndhist_axlabels(vectors, bins=500):
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]
    for dim in range(dims):
        h, bins = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)
    fig = plt.figure(figsize=(16, 9))
    sns.heatmap(hists, yticklabels=False, xticklabels=False)
    axes = fig.gca()
    axes.set(ylabel='dimensions', xlabel='values')
    #plt.xticks(np.linspace(*limits, len(bins)), bins)
    plt.xticks(range(len(bins)), bins)
    axes.xaxis.set_major_locator(matplotlib.ticker.AutoLocator())
    plt.yticks(range(dims+1), range(dims+1))
    axes.yaxis.set_major_locator(matplotlib.ticker.AutoLocator())
    print(dims)
    print(limits)

ndhist_axlabels(vectors)

但是,正如您所见,坐标轴标签非常错误。我的猜测是范围或限制存储在原始轴中的某个位置,但在切换回AutoLocator 时会丢失。非常感谢您朝正确的方向轻推。

【问题讨论】:

    标签: python matplotlib histogram seaborn heatmap


    【解决方案1】:

    也许你想多了。要绘制图像数据,可以使用imshow 并免费获取标记和格式。

    import numpy as np
    from matplotlib import pyplot as plt
    
    # sample data:
    vectors = np.random.randn(10000, 300) + np.random.randn(300)
    
    def ndhist(vectors, bins=500):
        limits = (vectors.min(), vectors.max())
        hists = []
        dims = vectors.shape[1]
    
        for dim in range(dims):
            h, _ = np.histogram(vectors[:, dim], bins=bins, range=limits)
            hists.append(h)
        hists = np.array(hists)
    
        fig, ax = plt.subplots(figsize=(16, 9))
    
        extent = [limits[0], limits[-1], hists.shape[0]-0.5, -0.5]  
        im = ax.imshow(hists, extent=extent, aspect="auto")
        fig.colorbar(im)
    
        ax.set(ylabel='dimensions', xlabel='values')
    
    ndhist(vectors)
    plt.show()
    

    【讨论】:

      【解决方案2】:

      如果您阅读docs,您会注意到xticklabels/yticklabels 参数被重载,因此如果您提供整数而不是字符串,它会将参数解释为xtickevery/@ 987654329@ 并仅在相应位置放置刻度。因此,在您的情况下,seaborn.heatmap(hists, yticklabels=50) 解决了您的 y 轴问题。

      关于您的 xtick 标签,我会明确地提供它们:

      xtickevery = 50 
      xticklabels = ['{:.1f}'.format(b) if ii%xtickevery == 0 else '' for ii, b in enumerate(bins)]
      sns.heatmap(hists, yticklabels=50, xticklabels=xticklabels)
      

      【讨论】:

      • yupp,也做到了这一点,但这不是 matplotlib 为尺寸本身拍照的“默认”...(如果我有几个 1000 怎么办?),但是,是的,可以忍受y 轴。但是对于 x 轴,即使使用 xticklabels=1,然后替换它们也不起作用:-/
      • 对不起,我还没说完就按了“post answer”。修改答案以修复两个轴。
      • 嗯,仍然是一个固定的宽度,导致奇怪的步长和绘制所有的刻度很慢......:-/
      【解决方案3】:

      终于想出了一个适合我的版本,并基于一些简单的线性映射使用AutoLocator...

      def ndhist(vectors, bins=1000, title=None):
          t = time.time()
          limits = (vectors.min(), vectors.max())
          hists = []
          dims = vectors.shape[1]
          for dim in range(dims):
              h, bs = np.histogram(vectors[:, dim], bins=bins, range=limits)
              hists.append(h)
          hists = np.array(hists)
      
          fig = plt.figure(figsize=(16, 12))
          sns.heatmap(
              hists,
              yticklabels=50,
              xticklabels=False
          )
      
          axes = fig.gca()
          axes.set(
              ylabel=f'dimensions ({dims} total)',
              xlabel=f'values (min: {limits[0]:.4g}, max: {limits[1]:.4g}, {bins} bins)',
              title=title,
          )
      
          def val_to_idx(val):
              # calc (linearly interpolated) index loc for given val
              return bins*(val - limits[0])/(limits[1] - limits[0])
          xlabels = [round(l, 3) for l in limits] + [
              v for v in matplotlib.ticker.AutoLocator().tick_values(*limits)[1:-1]
          ]
          # drop auto-gen labels that might be too close to limits
          d = (xlabels[4] - xlabels[3])/3
          if (xlabels[1] - xlabels[-1]) < d:
              del xlabels[-1]
          if (xlabels[2] - xlabels[0]) < d:
              del xlabels[2]
          xticks = [val_to_idx(val) for val in xlabels]
          axes.set_xticks(xticks)
          axes.set_xticklabels([f'{l:.4g}' for l in xlabels])
      
          plt.show()
          print(f'histogram generated in {time.time() - t:.2f}s')
      
      ndhist(np.random.randn(100000, 300), bins=1000, title='randn')
      

      感谢 Paul his answer 给了我这个想法。

      如果有更简单或更优雅的解决方案,我仍然会感兴趣。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2021-03-31
        • 2012-05-08
        • 1970-01-01
        • 2011-06-13
        • 1970-01-01
        相关资源
        最近更新 更多