【Posted】: 2021-05-04 23:02:29
【Question】:
The Python code below computes sentence similarity using the Universal Sentence Encoder.
from absl import logging
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns

module_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
model = hub.load(module_url)
print("module %s loaded" % module_url)

def embed(input):
    # Map a list of sentences to their USE embedding vectors
    return model(input)

def plot_similarity(labels, features, rotation):
    # Pairwise inner products of the embeddings (similarity matrix)
    corr = np.inner(features, features)
    print(corr)
    sns.set(font_scale=2.4)
    plt.subplots(figsize=(40, 30))
    g = sns.heatmap(
        corr,
        xticklabels=labels,
        yticklabels=labels,
        vmin=0,
        vmax=1,
        cmap="YlGnBu",
        linewidths=1.0)
    g.set_xticklabels(labels, rotation=rotation)
    g.set_title("Semantic Textual Similarity")

def run_and_plot(messages_):
    message_embeddings_ = embed(messages_)
    plot_similarity(messages_, message_embeddings_, 90)

messages = [
    "I want to know my savings account balance",
    "Show my bank balance",
    "Show me my account",
    "What is my bank balance",
    "Please Show my bank balance"
]
run_and_plot(messages)
plt.show()  # needed to display the figure when run as a script
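In `plot_similarity`, `np.inner(features, features)` produces an n×n matrix whose entry (i, j) is the dot product of embeddings i and j. Because the Universal Sentence Encoder's outputs are described as approximately unit-length, these dot products behave like cosine similarities, and the diagonal (each sentence against itself) sits at roughly 1.0. That is why an upper bound just below 1 filters out self-matches. The sketch below illustrates this with hypothetical, already-normalised toy vectors; `pairwise_similarity` is an illustrative helper, not part of the original code.

```python
import numpy as np

def pairwise_similarity(features):
    """Return the n x n matrix of inner products between the rows of
    `features`, mirroring np.inner(features, features) in plot_similarity."""
    features = np.asarray(features, dtype=float)
    return np.inner(features, features)

# Hypothetical toy "embeddings", each of unit length, so the
# inner product of two rows equals their cosine similarity.
toy = np.array([[1.0, 0.0],
                [0.6, 0.8],
                [0.0, 1.0]])
corr = pairwise_similarity(toy)
print(np.round(corr, 3))
```

The matrix is symmetric and its diagonal is 1, so each unordered pair appears twice off the diagonal.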
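I only want to focus on the sentences that look very similar, but the current heatmap shows every value.
So:
- Is there a way to show the heatmap only for values greater than 0.6 and less than 0.999?
- Is it possible to print the matching pairs whose values fall within a given range, i.e. between 0.6 and 0.99?
Thanks, Rohit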
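One way to answer the second question is to read only the upper triangle of the similarity matrix (so each pair is reported once and self-similarities are skipped) and keep the scores inside the range. This is a minimal NumPy sketch; `pairs_in_range` and the toy matrix are hypothetical, standing in for the real `np.inner(...)` output.

```python
import numpy as np

def pairs_in_range(labels, corr, low=0.6, high=0.99):
    """Return (label_i, label_j, score) for every unordered pair i < j
    whose similarity lies strictly between `low` and `high`."""
    corr = np.asarray(corr)
    # Indices strictly above the diagonal: the matrix is symmetric and
    # the diagonal holds each sentence's similarity with itself (~1.0).
    rows, cols = np.triu_indices(len(labels), k=1)
    scores = corr[rows, cols]
    keep = (scores > low) & (scores < high)
    return [(labels[i], labels[j], float(s))
            for i, j, s in zip(rows[keep], cols[keep], scores[keep])]

# Hypothetical similarity matrix standing in for np.inner(emb, emb).
labels = ["a", "b", "c"]
corr = np.array([[1.0, 0.7, 0.2],
                 [0.7, 1.0, 0.95],
                 [0.2, 0.95, 1.0]])
for left, right, score in pairs_in_range(labels, corr):
    print(f"{score:.3f}  {left!r} <-> {right!r}")
```

With the question's code, this could be called as `emb = embed(messages)` followed by `pairs_in_range(messages, np.inner(emb, emb), 0.6, 0.99)`.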
【Discussion】:
- Set vmin/vmax to the range you want, and set the under/over clipping colors to white?
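An alternative to the clipping-color suggestion is `seaborn.heatmap`'s `mask` argument, where `True` cells are not drawn. The sketch below builds such a mask with NumPy, hiding everything outside (0.6, 0.999) and, as an extra assumption, always hiding the diagonal in case self-similarities land just below 0.999. `out_of_range_mask` and the toy matrix are hypothetical.

```python
import numpy as np

def out_of_range_mask(corr, low=0.6, high=0.999):
    """Boolean mask for seaborn.heatmap(mask=...): True hides a cell.
    Only cells strictly inside (low, high) stay visible."""
    corr = np.asarray(corr)
    mask = ~((corr > low) & (corr < high))
    # Assumption: self-similarities are never of interest, so hide the
    # diagonal explicitly rather than relying on the upper bound alone.
    mask |= np.eye(len(corr), dtype=bool)
    return mask

# Hypothetical similarity matrix standing in for np.inner(features, features).
corr = np.array([[1.0, 0.7, 0.2],
                 [0.7, 1.0, 0.95],
                 [0.2, 0.95, 1.0]])
mask = out_of_range_mask(corr)
print(mask)
```

In `plot_similarity`, this would mean passing `mask=out_of_range_mask(corr)` to `sns.heatmap(...)`, optionally with `vmin=0.6` so the color scale spans only the visible values. The clipping-color route from the comment also works: copy a colormap, call its `set_under("white")` and `set_over("white")` methods, and pass it as `cmap` together with `vmin=0.6, vmax=0.999`.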
Tags: python python-3.x tensorflow seaborn heatmap