【问题标题】:Seaborn Heatmap - Display the heatmap only if values are above given thresholdSeaborn Heatmap - 仅当值高于给定阈值时才显示热图
【发布时间】:2021-05-04 23:02:29
【问题描述】:

下面的python代码显示句子相似度,它使用Universal Sentence Encoder来实现。

from absl import logging

import tensorflow as tf

import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns

module_url = "https://tfhub.dev/google/universal-sentence-encoder/4" 
model = hub.load(module_url)
print ("module %s loaded" % module_url)
def embed(input):
  return model(input)


def plot_similarity(labels, features, rotation):
  corr = np.inner(features, features)
  print(corr)
  sns.set(font_scale=2.4)
  plt.subplots(figsize=(40,30))
  g = sns.heatmap(
      corr,
      xticklabels=labels,
      yticklabels=labels,
      vmin=0,
      vmax=1,
      cmap="YlGnBu",linewidths=1.0)
  g.set_xticklabels(labels, rotation=rotation)
  g.set_title("Semantic Textual Similarity")

def run_and_plot(messages_):
  message_embeddings_ = embed(messages_)
  plot_similarity(messages_, message_embeddings_, 90)


messages = [
"I want to know my savings account balance",
"Show my bank balance",
"Show me my account",
"What is my bank balance",
"Please Show my bank balance"    

]

run_and_plot(messages)

输出显示为热图,如下所示,同时打印值

我只想关注看起来非常相似的句子,但是当前的热图显示了所有值。

所以

  1. 有没有一种方法可以仅查看范围大于 0.6 且小于 0.999 的值的热图?

  2. 是否可以打印位于给定范围内的匹配值对,即 0.6 和 0.99? 谢谢, 罗希特

【问题讨论】:

  • 将 vmin/vmax 设置在所需的范围内,并将 set clipping colors 上下设置为白色?

标签: python python-3.x tensorflow seaborn heatmap


【解决方案1】:

根据您的问题更新,这是一个修订版。显然,在网格中,不能删除单个单元格。但是我们可以大幅减少热图以仅显示相关的值对。热图中存在的随机分散的重要值越多,这种效果就越不明显。

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from copy import copy
import seaborn as sns

#semi-random data generation 
labels = list("ABCDE")
np.random.seed(123)
df = pd.DataFrame(np.random.randint(1, 100, (20, 5)))
df.columns = labels
df.A = df.B - df.D
df.C = df.B + df.A
df.E = df.A + df.C

#your correlation array
corr = df.corr().to_numpy()
print(corr)

#conditions for filtering 0.6<=r<=0.9
val_min = 0.6
val_max = 0.99

#plotting starts here
sns.set(font_scale=2.4)
#two axis objects just for comparison purposes
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,8))

#define the colormap with clipping values
my_cmap = copy(plt.cm.YlGnBu)
my_cmap.set_over("white")
my_cmap.set_under("white")

#ax1 - full set of conditions as in the initial version 
g1 = sns.heatmap(corr,
    xticklabels=labels,
    yticklabels=labels,
    vmin=val_min,
    vmax=val_max,
    cmap=my_cmap,
    linewidths=1.0,
    linecolor="grey",
    ax=ax1)

g1.set_title("Entire heatmap")

#ax2 - remove empty rows/columns
# use only lower triangle
corr = np.tril(corr)

#delete columns where all elements do not fulfill the conditions
ind_x,  = np.where(np.all(np.logical_or(corr<val_min, corr>val_max), axis=0))
corr = np.delete(corr, ind_x, 1)
#update x labels
map_labels_x = [item for i, item in enumerate(labels) if i not in ind_x]
    
#now the same for rows 
ind_y, = np.where(np.all(np.logical_or(corr<val_min, corr>val_max), axis=1))
corr = np.delete(corr, ind_y, 0)
#update y labels
map_labels_y = [item for i, item in enumerate(labels) if i not in ind_y]

#plot heatmap
g2 = sns.heatmap(corr,
    xticklabels=map_labels_x,
    yticklabels=map_labels_y,
    vmin=val_min,
    vmax=val_max,
    cmap=my_cmap,
    linewidths=1.0,
    linecolor="grey", ax=ax2)

g2.set_title("Reduced heatmap")

plt.show()

样本输出:

左侧,原始方法显示热图的所有元素。对,只保留相关的对。 问题(以及代码)排除了显着的负相关,例如 -0.95。如果不打算这样做,则应使用np.abs()

初步回答
我很惊讶没有人提供一个独立的解决方案,所以这里有一个:

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from copy import copy
import seaborn as sns

labels = list("ABCDE")
#semi-random data
np.random.seed(123)
df = pd.DataFrame(np.random.randint(1, 100, (20, 5)))
df.columns = labels
df.A = df.B - df.D
df.E = df.A + df.C

corr = df.corr()
sns.set(font_scale=2.4)
plt.subplots(figsize=(10,8))

#define the cmap with clipping values
my_cmap = copy(plt.cm.YlGnBu)
my_cmap.set_over("white")
my_cmap.set_under("white")

g = sns.heatmap(corr,
    xticklabels=labels,
    yticklabels=labels,
    vmin=0.5,
    vmax=0.9,
    cmap=my_cmap,
    linewidths=1.0,
    linecolor="grey")

g.set_xticklabels(labels, rotation=60)
g.set_title("Important!")

plt.show()

示例输出:

【讨论】:

  • 这部分解决了问题。请注意,热图大小仍然保持不变。我也想根据最小值和最大值减小热图大小。这可能吗?
  • 另外,是否可以只打印在给定阈值内匹配的值对?
  • 那么,在这种情况下,应该从热图中删除 D 吗?这从描述中并不清楚。这是可能的,但我今天没有时间,所以也许其他人会在此期间发布解决方案。
  • 谢谢@Mr。 T ,是否可以只打印在给定阈值内匹配的值对?
  • 我现在有时间研究它,更新版本尽可能地删除了所有空白空间。第二个问题 - 在您定义的条件下打印所有相关对 - 与 seaborn/heatmap 主题明显不同,应单独提出。顺便说一句,我希望您知道您的问题不包括显着的负相关(初始版本中的 D 列),这是一个有意的功能,而不是错误。包含负相关系数将是另一个主题,因为您还必须更改颜色图
猜你喜欢
  • 2019-11-28
  • 2022-08-23
  • 2017-09-28
  • 1970-01-01
  • 2021-12-20
  • 1970-01-01
  • 2023-04-03
  • 2023-03-04
  • 2018-03-24
相关资源
最近更新 更多