【问题标题】:Numpy NdArray - How to set threshold and then sort valuesNumpy NdArray - 如何设置阈值然后对值进行排序
【发布时间】:2018-07-14 23:44:18
【问题描述】:

我有一个名为 'sim' 的 numpy ndarray (4 x 4) 表示 4 个项目 (a,b,c,d) 之间的相似度值。

array([[ 1.        ,  0.        ,  0.5547002 ,  0.73960026],
       [ 0.        ,  1.        ,  0.        ,  0.66666667],
       [ 0.5547002 ,  0.        ,  1.        ,  0.33333333],
       [ 0.73960026,  0.66666667,  0.33333333,  1.        ]])

dataset_u 是一个包含 [a,b,c,d] 的列表 以下代码对数组进行排序,然后根据项目 a、b、c、d 的相似度值确定前 3 个项目(related_count)。

related_count =3
dataidx = np.asarray(dataset_u) # a,b,c,d
indices = np.argsort(-sim, axis=1)
result = np.hstack((dataidx[:, None], dataidx[indices]))
m1 = result.shape[0]
mask = np.c_[[True] * m1, result[:, 1:] != result[:, 0, None]]
final_mat = result[mask].reshape(m1, -1)
dfdownload = pd.DataFrame(final_mat[:, 1:related_count], index=final_mat[:, 0])

df下载:

如何修改上述代码,使其在对数组进行排序之前只考虑 >=0.5 的值? 例如,对于项目“a”,预期的相关项目是“d”、“c”,而对于项目“b”,其相关项目只有“d”(0.66666667)。

【问题讨论】:

  • 感谢您的评论,我已经更新了代码。 sim 指的是相似度 ndarray,dataset_u 是 [a,b,c,d] 的列表
  • 你的预期输出是什么?

标签: python pandas sorting numpy threshold


【解决方案1】:

我对@9​​87654321@ 和pandas 都很陌生,所以这可能不是最好的方法,我只是希望它能引导您找到更好的解决方案。

sim_copy = sim.copy()
sim_copy[sim_copy <= 0.5] = 0
bool_sim = np.asarray(sim_copy, dtype=bool)
dfdownload.mask(~bool_sim[:, :-1])
# -1 can be replaced with related_count, but its value seems wrong.

输出

     0    1    2
a    d  NaN    b
b  NaN    a  NaN
c    a  NaN    d
d    a    b  NaN

附带说明,related_count 的值应该是 4 而不是 3,但我也不太确定 :)。

【讨论】:

    猜你喜欢
    • 2018-05-12
    • 2020-03-19
    • 1970-01-01
    • 2020-09-01
    • 2012-04-12
    • 1970-01-01
    相关资源
    最近更新 更多