【问题标题】:numpy.ndarray sorting to return indicesnumpy.ndarray 排序以返回索引
【发布时间】:2021-12-27 17:06:57
【问题描述】:
errors = [[ 0.,  9., 12.,  9., 14.,  5.,  4., 10.,  8.,  8.,  6.,  5.,  9.],
   [ 9.,  0., 22., 16., 11., 12.,  9., 21., 14., 11., 16., 15.,  9.],
   [12., 22.,  0., 18., 23., 16., 10., 22., 21., 13., 13., 13., 15.],
   [ 9., 16., 18.,  0., 11., 12.,  8., 19., 20., 11.,  7.,  9., 13.],
   [14., 11., 23., 11.,  0., 11.,  7., 18.,  9., 10.,  7.,  7., 14.],
   [ 5., 12., 16., 12., 11.,  0.,  7., 13., 15.,  5.,  8., 10.,  9.],
   [ 4.,  9., 10.,  8.,  7.,  7.,  0.,  8.,  8.,  3.,  4.,  7.,  4.],
   [10., 21., 22., 19., 18., 13.,  8.,  0., 18., 12., 14., 13., 11.],
   [ 8., 14., 21., 20.,  9., 15.,  8., 18.,  0.,  5., 11., 16., 10.],
   [ 8., 11., 13., 11., 10.,  5.,  3., 12.,  5.,  0.,  8.,  9.,  5.],
   [ 6., 16., 13.,  7.,  7.,  8.,  4., 14., 11.,  8.,  0., 11.,  7.],
   [ 5., 15., 13.,  9.,  7., 10.,  7., 13., 16.,  9., 11.,  0.,  4.],
   [ 9.,  9., 15., 13., 14.,  9.,  4., 11., 10.,  5.,  7.,  4.,  0.]])

上面是形状 (13,13) 的 numpy.ndarray,使用 13 个特征中的两个特征在某个分类任务中获得错误。

这里的任务是找到最小的可实现的错误和 实现这个最小误差的特征对

因为数据很少,最小的误差可以用眼睛看到,它的 3 和特征对是 (6,9) 或 (9,6)。

(0值的诊断线本身就是特征,所以不包括在内)。

我曾尝试使用 argsort 进行此操作,但它仅对每一行进行单独排序,我没有得到答案。

请帮忙。

【问题讨论】:

    标签: python numpy-ndarray np.argsort


    【解决方案1】:

    严格来说,您需要屏蔽的不是零,而是对角线值。谁知道呢,可能会有一对会给出绝配。所以我会这样做:

    # This modifies errors filling the diagonal with inf-s. Make a copy if you need to keep the original result intact.
    np.fill_diagonal(errors, np.inf)
    np.nonzero(errors == errors.min())
    

    【讨论】:

    • 如果一对或多对真的实现了零错误会发生什么。你将如何返回它的索引。正如我阅读文档时所说,np.nonzero 返回非零元素的索引。
    • 这里np.nonzero 返回errors == errors.min() 的非零元素的索引,这是一个具有非零元素的布尔数组,其中errors 等于errors.min()。如果errors.min() 为零,它将返回errors 的零元素的索引。
    【解决方案2】:

    如果我正确理解你的任务:

    • errors-array 中找到最小值
    • 找出这个最小错误的索引

    我使用min()nonzero() 实现了这一点。这是我对您的问题的解决方案:

    errors = np.array((
       [ 0.,  9., 12.,  9., 14.,  5.,  4., 10.,  8.,  8.,  6.,  5.,  9.],
       [ 9.,  0., 22., 16., 11., 12.,  9., 21., 14., 11., 16., 15.,  9.],
       [12., 22.,  0., 18., 23., 16., 10., 22., 21., 13., 13., 13., 15.],
       [ 9., 16., 18.,  0., 11., 12.,  8., 19., 20., 11.,  7.,  9., 13.],
       [14., 11., 23., 11.,  0., 11.,  7., 18.,  9., 10.,  7.,  7., 14.],
       [ 5., 12., 16., 12., 11.,  0.,  7., 13., 15.,  5.,  8., 10.,  9.],
       [ 4.,  9., 10.,  8.,  7.,  7.,  0.,  8.,  8.,  3.,  4.,  7.,  4.],
       [10., 21., 22., 19., 18., 13.,  8.,  0., 18., 12., 14., 13., 11.],
       [ 8., 14., 21., 20.,  9., 15.,  8., 18.,  0.,  5., 11., 16., 10.],
       [ 8., 11., 13., 11., 10.,  5.,  3., 12.,  5.,  0.,  8.,  9.,  5.],
       [ 6., 16., 13.,  7.,  7.,  8.,  4., 14., 11.,  8.,  0., 11.,  7.],
       [ 5., 15., 13.,  9.,  7., 10.,  7., 13., 16.,  9., 11.,  0.,  4.],
       [ 9.,  9., 15., 13., 14.,  9.,  4., 11., 10.,  5.,  7.,  4.,  0.]))
    
    min_error = errors[errors!=0].min()
    pairs = np.nonzero(errors == errors[errors!=0].min())
    

    这给了我预期的输出 min_error = 3pairs = (array([6, 9]), array([9, 6]))

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2013-08-19
      • 1970-01-01
      • 2016-05-20
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多