【问题标题】:Fastest algorithm for Kth smallest Element (or median) finding on 2 Dimensional Array?在二维数组上查找第 K 个最小元素(或中位数)的最快算法?
【发布时间】:2021-03-03 02:04:54
【问题描述】:

我看到很多关于相关主题的 SO 主题,但没有一个提供有效的方法。

我想在二维数组[1..M][1..N] 上找到k-th 最小元素(或中位数),其中每一行按升序排序,所有元素都是不同的。

我认为有O(M log MN) 解决方案,但我不知道如何实施。 (中位数的中位数或使用具有线性复杂性的分区是一些方法,但现在不知道了......)。

这是一个古老的谷歌面试问题,可以在Here上搜索。

但现在我想要提示或描述最有效的算法最快算法)。

我也读过一篇关于here的论文,但我不明白。

更新 1:找到了一个解决方案 here,但当维度为奇数时。

【问题讨论】:

  • 你可能会在the computer science stackexchange得到更多有见地的答案
  • 我在这里找到了非常完美的解决方案,我认为这更像是 CS。 @Stef 谢谢。
  • 您询问的是仅排序的行,还是排序的行和列。您的描述和界限仅对已排序的行看起来是合理的。但是您的所有链接都是到已排序的行和列。
  • @btilly 哇,谢谢我读了很多关于我的问题的好答案。仅对行进行了排序。我们对专栏一无所知。 (因为面试问题中没有提到)。我添加链接是因为我认为这是更具体的情况。不确定。
  • @Spektre 有一个解决方案可以在 O(M log MN) 中找到答案。我认为这是与您的时间相比的下限?不是吗?请在此处添加您的答案。

标签: java python arrays algorithm data-structures


【解决方案1】:

所以要解决这个问题,它有助于解决一个稍微不同的问题。我们想知道每行中总第 k 个截止点的上/下限。然后我们可以通过,验证下界或以下的事物数量k,并且它们之间只有一个值。

我想出了一个策略,在所有行中同时针对这些边界进行二进制搜索。作为二进制搜索,它“应该”通过O(log(n))。每遍涉及O(m)工作,共O(m log(n))次。我把应该放在引号中,因为我没有证据证明它实际上需要O(log(n)) 通行证。事实上,有可能在一行中过于激进,从其他行中发现所选择的枢轴已关闭,然后不得不退出。但我相信它几乎没有退缩,实际上是O(m log(n))

策略是跟踪下限、上限和中间的每一行。每次传递,我们都会创建一系列加权范围,从低到中,从中到高,从高到末端,权重是其中事物的数量,值是系列中的最后一个。然后,我们在该数据结构中找到第 k 个值(按权重),并将其用作在每个维度中进行二分搜索的枢轴。

如果枢轴超出从下到上的范围,我们会通过在纠正错误的方向上扩大区间来进行纠正。

当我们有正确的顺序时,我们就有了答案。

有很多边缘情况,所以盯着完整的代码可能会有所帮助。

我还假设每一行的所有元素都是不同的。如果不是,您可能会陷入无限循环。 (解决这意味着更多的边缘情况......)

import random

# This takes (k, [(value1, weight1), (value2, weight2), ...])
def weighted_kth (k, pairs):
    # This does quickselect for average O(len(pairs)).
    # Median of medians is deterministically the same, but a bit slower
    pivot = pairs[int(random.random() * len(pairs))][0]

    # Which side of our answer is the pivot on?
    weight_under_pivot = 0
    pivot_weight = 0
    for value, weight in pairs:
        if value < pivot:
            weight_under_pivot += weight
        elif value == pivot:
            pivot_weight += weight

    if weight_under_pivot + pivot_weight < k:
        filtered_pairs = []
        for pair in pairs:
            if pivot < pair[0]:
                filtered_pairs.append(pair)
        return weighted_kth (k - weight_under_pivot - pivot_weight, filtered_pairs)
    elif k <= weight_under_pivot:
        filtered_pairs = []
        for pair in pairs:
            if pair[0] < pivot:
                filtered_pairs.append(pair)
        return weighted_kth (k, filtered_pairs)
    else:
        return pivot

# This takes (k, [[...], [...], ...])
def kth_in_row_sorted_matrix (k, matrix):
    # The strategy is to discover the k'th value, and also discover where
    # that would be in each row.
    #
    # For each row we will track what we think the lower and upper bounds
    # are on where it is.  Those bounds start as the start and end and
    # will do a binary search.
    #
    # In each pass we will break each row into ranges from start to lower,
    # lower to mid, mid to upper, and upper to end.  Some ranges may be
    # empty.  We will then create a weighted list of ranges with the weight
    # being the length, and the value being the end of the list.  We find
    # where the k'th spot is in that list, and use that approximate value
    # to refine each range.  (There is a chance that a range is wrong, and
    # we will have to deal with that.)
    #
    # We finish when all of the uppers are above our k, all the lowers
    # one are below, and the upper/lower gap is more than 1 only when our
    # k'th element is in the middle.

    # Our data structure is simply [row, lower, upper, bound] for each row.
    data = [[row, 0, min(k, len(row)-1), min(k, len(row)-1)] for row in matrix]
    is_search = True
    while is_search:
        pairs = []
        for row, lower, upper, bound in data:
            # Literal edge cases
            if 0 == upper:
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            elif lower == bound:
                pairs.append((row[lower], lower + 1))
            elif lower + 1 == upper: # No mid.
                pairs.append((row[lower], lower + 1))
                pairs.append((row[upper], 1))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))
            else:
                mid = (upper + lower) // 2
                pairs.append((row[lower], lower + 1))
                pairs.append((row[mid], mid - lower))
                pairs.append((row[upper], upper - mid))
                if upper < bound:
                    pairs.append((row[bound], bound - upper))

        pivot = weighted_kth(k, pairs)

        # Now that we have our pivot, we try to adjust our parameters.
        # If any adjusts we continue our search.
        is_search = False
        new_data = []
        for row, lower, upper, bound in data:
            # First cases where our bounds weren't bounds for our pivot.
            # We rebase the interval and either double the range.
            # - double the size of the range
            # - go halfway to the edge
            if 0 < lower and pivot <= row[lower]:
                is_search = True
                if pivot == row[lower]:
                    new_data.append((row, lower-1, min(lower+1, bound), bound))
                elif upper <= lower:
                    new_data.append((row, lower-1, lower, bound))
                else:
                    new_data.append((row, max(lower // 2, lower - 2*(upper - lower)), lower, bound))
            elif upper < bound and row[upper] <= pivot:
                is_search = True
                if pivot == row[upper]:
                    new_data.append((row, upper-1, upper+1, bound))
                elif lower < upper:
                    new_data.append((row, upper, min((upper+bound+1)//2, upper + 2*(upper - lower)), bound))
                else:
                    new_data.append((row, upper, upper+1, bound))
            elif lower + 1 < upper:
                if upper == lower+2 and pivot == row[lower+1]:
                    new_data.append((row, lower, upper, bound)) # Looks like we found the pivot.
                else:
                    # We will split this interval.
                    is_search = True
                    mid = (upper + lower) // 2
                    if row[mid] < pivot:
                        new_data.append((row, mid, upper, bound))
                    elif pivot < row[mid] pivot:
                        new_data.append((row, lower, mid, bound))
                    else:
                        # We center our interval on the pivot
                        new_data.append((row, (lower+mid)//2, (mid+upper+1)//2, bound))
            else:
                # We look like we found where the pivot would be in this row.
                new_data.append((row, lower, upper, bound))
        data = new_data # And set up the next search
    return pivot

【讨论】:

  • 所有元素都是不同的。真正的考虑。
  • @MokholiaPokholia 请告诉我,如果您发现它不能按承诺工作。
  • 非常好,让我花几分钟检查一下。我首先想到的一个问题是,在深入了解复杂性之前,我们如何首先证明复杂性?
  • 对我来说是一个小误解。你的时间复杂度是多少?
  • @MokholiaPokholia 我没有证据。但是。我相信时间复杂度是O(m log(n))。我有另一个变体可以处理重复并具有稍微更好的行为,但我再次没有性能证明。 (不同之处在于将间隔切成三分之一,使用范围技巧来建立第 k 个值的上限/下限。然后丢弃绝对不在范围内的行部分。)
【解决方案2】:

已添加另一个答案以提供实际解决方案。由于 cmets 上有一个相当大的兔子洞,这个已经留下了。


我相信最快的解决方案是 k-way 合并算法。这是一种O(N log K) 算法,将K 排序列表与总共N 项合并到一个大小为N 的排序列表中。

https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge

给定一个MxN 列表。这最终是O(MNlog(M))。但是,这是为了对整个列表进行排序。由于您只需要第一个K 最小的项目而不是所有N*M,因此性能是O(Klog(M))。假设O(K) &lt;= O(M),这比您要寻找的要好得多。

虽然这假设您有N 大小为M 的排序列表。如果您实际上有M 大小为N 的排序列表,则可以通过更改循环数据的方式轻松处理(请参阅下面的伪代码),尽管这确实意味着性能是O(K log(N))

k-way 合并只是将每个列表的第一项添加到堆或其他具有 O(log N) insert 和 O(log N) find-mind 的数据结构中。

k-way 合并的伪代码看起来有点像这样:

  1. 对于每个排序列表,将第一个值插入到数据结构中,并通过某种方式确定该值来自哪个列表。 IE:您可以在数据结构中插入[value, row_index, col_index] 而不仅仅是value。这还可以让您轻松处理对列或行的循环。
  2. 从数据结构中删除最小值并附加到排序列表中。
  3. 假设步骤#2 中的项目来自列表I,将列表I 中的下一个最小值添加到数据结构中。 IE:如果值为row 5 col 4 (data[5][4])。然后,如果您使用行作为列表,那么下一个值将是row 5 col 5 (data[5][5])。如果您使用列,则下一个值为row 6 col 4 (data[6][4])。像 #1 一样将下一个值插入到数据结构中(即:[value, row_index, col_index]
  4. 根据需要返回第 2 步。

根据您的需要,执行步骤 2-4 K 次。

【讨论】:

【解决方案3】:

可能是我遗漏了一些东西,但是如果您的 NxM 矩阵 AM 行已经按升序排序,没有重复元素,那么 k-th 行的最小值只是选择 k-行中的第一个元素,即O(1)。要移动到 2D,您只需选择 k-th 列,将其升序排列 O(M.log(M)) 并再次选择 k-th 元素导致 O(N.log(N))

  1. 让我们有矩阵A[N][M]

    元素是A[column][row]

  2. Ak-th 列进行升序排列O(M.log(M))

    所以对A[k][i] 进行排序,其中i = { 1,2,3,...M } 升序

  3. 选择A[k][k]作为结果

如果您想要 A 中所有元素中的第 k 个最小元素,那么您需要以类似于归并排序的形式利用已排序的行。

  1. 创建空列表c[] 用于保存k 最小值

  2. 处理列

  3. 创建临时数组b[]

    它保存已处理的列快速排序升序O(N.log(N))

  4. 合并c[]b[],这样c[] 可以容纳k 最小值

    使用临时数组d[] 将导致O(k+n)

  5. 如果在合并期间未使用来自b 的任何项目,则停止处理列

    这可以通过添加标志数组f 来完成,该数组将保存在合并期间从b,c 获取值的位置,然后检查是否从b 获取任何值

  6. 输出c[k-1]

如果我们认为k小于M,那么最终的复杂度是O(min(k,M).N.log(N)),我们可以重写为O(k.N.log(N)),否则O(M.N.log(N))。同样平均而言,要迭代的列数更不可能~(1+(k/N)),因此平均复杂度为~O(N.log(N)),但这只是我的疯狂猜测,可能是错误的。

这里是小 C++/VCL 示例:

//$$---- Form CPP ----
//---------------------------------------------------------------------------
#include <vcl.h>
#pragma hdrstop
#include "Unit1.h"
#include "sorts.h"
//---------------------------------------------------------------------------
#pragma package(smart_init)
#pragma resource "*.dfm"
TForm1 *Form1;
//---------------------------------------------------------------------------
const int m=10,n=8; int a[m][n],a0[m][n]; // a[col][row]
//---------------------------------------------------------------------------
void generate()
    {
    int i,j,k,ii,jj,d=13,b[m];
    Randomize();
    RandSeed=0x12345678;
    // a,a0 = some distinct pseudorandom values (fully ordered asc)
    for (k=Random(d),j=0;j<n;j++)
     for (i=0;i<m;i++,k+=Random(d)+1)
      { a0[i][j]=k; a[i][j]=k; }
    // schuffle a
    for (j=0;j<n;j++)
     for (i=0;i<m;i++)
        {
        ii=Random(m);
        jj=Random(n);
        k=a[i][j]; a[i][j]=a[ii][jj]; a[ii][jj]=k;
        }
    // sort rows asc
    for (j=0;j<n;j++)
        {
        for (i=0;i<m;i++) b[i]=a[i][j];
        sort_asc_quick(b,m);
        for (i=0;i<m;i++) a[i][j]=b[i];
        }

    }
//---------------------------------------------------------------------------
int kmin(int k) // k-th min from a[m][n] where a rows are already sorted
    {
    int i,j,bi,ci,di,b[n],*c,*d,*e,*f,cn;
    c=new int[k+k+k]; d=c+k; f=d+k;
    // handle edge cases
    if (m<1) return -1;
    if (k>m*n) return -1;
    if (m==1) return a[0][k];
    // process columns
    for (cn=0,i=0;i<m;i++)
        {
        // b[] = sorted_asc a[i][]
        for (j=0;j<n;j++) b[j]=a[i][j];     // O(n)
        sort_asc_quick(b,n);                // O(n.log(n))
        // c[] = c[] + b[] asc sorted and limited to cn size
        for (bi=0,ci=0,di=0;;)              // O(k+n)
            {
                 if ((ci>=cn)&&(bi>=n)) break;
            else if (ci>=cn)     { d[di]=b[bi]; f[di]=1; bi++; di++; }
            else if (bi>= n)     { d[di]=c[ci]; f[di]=0; ci++; di++; }
            else if (b[bi]<c[ci]){ d[di]=b[bi]; f[di]=1; bi++; di++; }
            else                 { d[di]=c[ci]; f[di]=0; ci++; di++; }
            if (di>k) di=k;
            }
        e=c; c=d; d=e; cn=di;
        for (ci=0,j=0;j<cn;j++) ci|=f[j];   // O(k)
        if (!ci) break;
        }
    k=c[k-1];
    delete[] c;
    return k;
    }
//---------------------------------------------------------------------------
__fastcall TForm1::TForm1(TComponent* Owner):TForm(Owner)
    {
    int i,j,k;
    AnsiString txt="";

    generate();

    txt+="a0[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
     for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a0[i][j]);

    txt+="\r\na[][]\r\n";
    for (j=0;j<n;j++,txt+="\r\n")
     for (i=0;i<m;i++) txt+=AnsiString().sprintf("%4i ",a[i][j]);

    k=20;
    txt+=AnsiString().sprintf("\r\n%ith smallest from a0 = %4i\r\n",k,a0[(k-1)%m][(k-1)/m]);
    txt+=AnsiString().sprintf("\r\n%ith smallest from a  = %4i\r\n",k,kmin(k));

    mm_log->Lines->Add(txt);
    }
//-------------------------------------------------------------------------

忽略 VCL 的东西。函数 generate 计算 a0, a 矩阵,其中 a0 已完全排序,a 仅对行进行排序,并且所有值都是不同的。函数kmin 是上面描述的算法,从a[m][n] 返回第k 个最小值对于排序,我使用了这个:

template <class T> void sort_asc_quick(T *a,int n)
    {
    int i,j; T a0,a1,p;
    if (n<=1) return;                                   // stop recursion
    if (n==2)                                           // edge case
        {
        a0=a[0];
        a1=a[1];
        if (a0>a1) { a[0]=a1; a[1]=a0; }                // condition
        return;
        }
    for (a0=a1=a[0],i=0;i<n;i++)                        // pivot = midle (should be median)
        {
        p=a[i];
        if (a0>p) a0=p;
        if (a1<p) a1=p;
        } if (a0==a1) return; p=(a0+a1+1)/2;            // if the same values stop
    if (a0==p) p++;
    for (i=0,j=n-1;i<=j;)                               // regroup
        {
        a0=a[i];
        if (a0<p) i++; else { a[i]=a[j]; a[j]=a0; j--; }// condition
        }
    sort_asc_quick(a  ,  i);                            // recursion a[]<=p
    sort_asc_quick(a+i,n-i);                            // recursion a[]> p
    }

这里是输出:

a0[][]
  10   17   29   42   54   66   74   85   90  102 
 112  114  123  129  142  145  146  150  157  161 
 166  176  184  191  195  205  213  216  222  224 
 226  237  245  252  264  273  285  290  291  296 
 309  317  327  334  336  349  361  370  381  390 
 397  398  401  411  422  426  435  446  452  462 
 466  477  484  496  505  515  522  524  525  530 
 542  545  548  553  555  560  563  576  588  590 

a[][]
 114  142  176  264  285  317  327  422  435  466 
 166  336  349  381  452  477  515  530  542  553 
 157  184  252  273  291  334  446  524  545  563 
  17  145  150  237  245  290  370  397  484  576 
  42  129  195  205  216  309  398  411  505  560 
  10  102  123  213  222  224  226  390  496  555 
  29   74   85  146  191  361  426  462  525  590 
  54   66   90  112  161  296  401  522  548  588 

20th smallest from a0 =  161

20th smallest from a  =  161

这个例子只迭代了 5 列...

【讨论】:

  • 非常好,这种方法如何实现 O(M log MN)?
  • @MounaMokhiab 我编辑了我的答案......添加了我刚刚忙忙碌碌的例子......我和你一样认为部分排序 a 排序会导致 O(M.log(M.N)) 但看起来我是错误,因为它会导致O(M.N.log(N))。但是我做了一些调整(因为我们不需要对整个矩阵进行排序,只对前 k 个最小元素进行排序),因此复杂度差异 ....
  • 确保我们有 M*N 矩阵意味着 M 行和 N 列这样 M 行被排序并且没有重复的元素。
  • 你在 OP 中看到的肯定是这个定义。
【解决方案4】:

似乎最好的方法是在越来越大的块中进行 k 路合并。 k-way 合并试图建立一个排序列表,但我们不需要它排序,也不需要考虑每个元素。相反,我们将创建一个半排序的区间。间隔将被排序,但仅限于最高值。

https://en.wikipedia.org/wiki/K-way_merge_algorithm#k-way_merge

我们使用与 k 路合并相同的方法,但有所不同。基本上它旨在间接构建一个半排序的子列表。例如,不是找到 [1,2,3,4,5,6,7,8,10] 来确定 K=10,而是找到类似 [(1,3),(4,6), (7,15)]。使用 K-way 合并,我们每次从每个列表中考虑 1 个项目。在这种悬停方法中,当从给定列表中拉取时,我们首先要考虑 Z 个项目,然后是 2 * Z 个项目,然后是 2 * 2 * Z 个项目,所以第 i 次是 2^i * Z 个项目。给定一个 MxN 矩阵,这意味着我们需要从列表 M 次中提取 O(log(N)) 个项目。

  1. 对于每个排序列表,将第一个K 子列表插入到数据结构中,并通过某种方式确定值来自哪个列表。我们希望数据结构使用我们插入的子列表中的最大值。在这种情况下,我们需要类似 [max_value of sublist, row index, start_index, end_index]。 O(m)
  2. 从数据结构中删除最小值(现在是值列表)并附加到排序列表中。 O(log (m))
  3. 假设第 2 步中的项目来自列表 I,将列表 2^i * Z 中的下一个 2^i * Z 值添加到第 i 次从该特定列表中提取的数据结构中(基本上只是数字的两倍存在于刚刚从数据结构中删除的子列表中)。 O(log m)
  4. 如果半排序子列表的大小大于 K,使用二分查找查找第 k 个值。 O(log N))。如果数据结构中剩余任何子列表,其中最小值小于 k。转到第 1 步,将列表作为输入,新的 Kk - (size of semi-sorted list)
  5. 如果半排序子列表的大小等于K,则返回半排序子列表中的最后一个值,这是第K个值。
  6. 如果半排序子列表的大小小于 K,则返回步骤 2。

至于性能。让我们在这里看看:

  • 采用O(m log m) 将初始值添加到数据结构中。
  • 它最多需要考虑O(m) 子列表,每个子列表需要O(log n) 时间用于`O(m log n)。
  • 它需要在最后执行二分查找,O(log m),如果不确定 K 的值是多少(步骤 4),它可能需要将问题简化为递归子列表,但我不认为这会影响大 O。编辑:我相信这只是在最坏的情况下增加了另一个 O(mlog(n)),这对大 O 没有影响。

所以看起来它是O(mlog(m) + mlog(n)) 或只是O(mlog(mn))

作为优化,如果 K 高于NM/2,则在考虑最小值时考虑最大值,在考虑最大值时考虑最小值。这将大大提高 K 接近 NM 时的性能。

【讨论】:

    【解决方案5】:

    btillyNuclearman 的答案提供了两种不同的方法,一种binary search 和一种k-way merge 的行。

    我的建议是结合这两种方法。

    • 如果 k 很小(比如说小于 M 乘以 2 或 3)或很大(对于 simmetry,接近 N x M) 足够了,找到对行进行 M 路合并的 kth 元素。当然,我们不应该合并所有元素,只合并前k

    • 否则,开始检查矩阵的第一列和最后一列,以找到最小值(女巫在第一列)和最大值(在最后一列)。

    • 将第一个关键值估计为这两个值的线性组合。类似pivot = min + k * (max - min) / (N * M)

    • 在每一行中执行二进制搜索以确定最后一个元素(更接近)不大于枢轴。简单地推导出小于或等于枢轴的元素数。比较那些与 k 的总和将知道选择的枢轴值是太大还是太小,让我们相应地修改它。跟踪所有行之间的最大值,它可能是第 k 个元素或仅用于评估下一个枢轴。如果我们将所述和视为枢轴的函数,那么现在的数值问题是找到sum(pivot) - k 的零点,这是一个单调(离散)函数。在最坏的情况下,我们可以使用二分法(对数复杂度)或割线法。

    • 我们可以理想地将每一行划分为三个范围:

      • 在左侧,元素 肯定 小于或等于 kth 元素。
      • 在中间,未确定的范围。
      • 在右侧,元素 肯定 大于 kth 元素。
    • 不确定的范围将在每次迭代时减少,最终对于大多数行变为空。在某些时候,仍然在未确定范围内的元素数量,分散在整个矩阵中,将小到足以诉诸这些范围的单个 M 路合并。

    • 如果我们将单次迭代的时间复杂度视为O(MlogN)M 二分搜索,我们需要将其乘以枢轴收敛到的值所需的迭代次数kth-元素,可以是O(logNM)。如果 N > M,则此总和为 O(MlogNlogM)O(MlogNlogN)

    • 注意,如果算法用于求中位数,最后一步是M路合并,很容易找到(k + 1)th - 元素也是。

    【讨论】:

    • 有趣的算法。我正在考虑做类似的事情,但不确定它是否能正常工作(或性能更高),所以只是坚持使用 k-way 合并。我相信分区位是我使之成为可能所缺少的,因此为解决这个问题而感到自豪。似乎是一种可靠的方法,但不能 100% 确定它是正确的,但似乎足够接近可以使用。
    猜你喜欢
    • 1970-01-01
    • 2015-01-15
    • 2021-01-13
    • 1970-01-01
    • 2015-12-02
    • 2021-09-05
    • 1970-01-01
    • 2019-08-28
    • 2011-04-12
    相关资源
    最近更新 更多