【问题标题】:Java implementation of MergeSort; cannot find the bug合并排序的 Java 实现;找不到错误
【发布时间】:2020-03-26 15:07:50
【问题描述】:

好的,这是那些令人绝望的问题之一。我正在尝试实现自下而上的 MS 来排序和整数数组。但是,老天爷,我似乎找不到错误...

import java.util.Scanner;

public class A2 {

    public static boolean less(Integer v, Integer w) {
        return v.compareTo(w) < 0;
    }

    public static void sort(int[] a) {
        int N = a.length;
        int[] aux = new int[N];
        for (int sz = 1; sz < N; sz = sz + sz)
            for (int lo = 0; lo < N - sz; lo += sz + sz)
                merge(a, aux, lo, lo + sz - 1, Math.min(lo + sz + sz - 1, N - 1));
    }

    public static void merge(int[] a, int aux[], int lo, int mid, int hi) {
        int i = lo;
        int j = mid + 1;

        for (int k = lo; k <= hi; k++)
            aux[k] = a[k];

        for (int k = lo; k <= hi; k++) 
            if (i > mid)
                a[k] = aux[j++];
            else if (j > hi)
                a[k] = aux[i++];
            else if (less(aux[j], aux[i]))
                a[k] = a[j++];
            else
                a[k] = a[i++];

    }

    public static void main(String[] args) {
        int next = 0;
        Scanner scanner = new Scanner(System.in);
        int size = Integer.parseInt(scanner.nextLine());
        int[] v = new int[size];
        String s = scanner.nextLine();
        scanner.close();
        String[] sa = s.split("[\\s]+");
        while (next < size) {
            v[next] = Integer.parseInt(sa[next]);
            next ++;
        }
        for (Integer i : v)
            System.out.print(i + " ");
        System.out.println();
        System.out.println("----------------------------------");
        sort(v);
        for (int i = 0; i < size; i++)
            System.out.print(v[i] + " ");
        System.out.println();
    }
}

main 函数中,我打印了数组的元素,只是为了确定问题出在排序上。第一个数字只是数组的大小。该错误位于sort()merge() 中。 以下是一些示例输出:

9
10 45 20 5 -6 80 99 -4 0
10 45 20 5 -6 80 99 -4 0 
----------------------------------
-6 -4 -4 -6 -4 -4 -6 0 99 

6
6 7 3 2 4 1
6 7 3 2 4 1 
----------------------------------
1 1 1 4 6 7 

5
6 5 2 3 4
6 5 2 3 4 
----------------------------------
2 3 4 5 6 

最后一个看起来还不错。

请帮帮我,我一直在转来转去,似乎找不到错误。

【问题讨论】:

标签: java algorithm sorting debugging mergesort


【解决方案1】:

问题出在 merge() 方法中:在循环的最后 2 种情况下,您从 a 复制值,而不是从 aux 复制值。复制 a[j++] 时没有问题,但复制 a[i++] 时,该值可能已被覆盖。

考虑到右侧切片中的值是复制后才写入的,所以只需要保存左侧切片即可。

这是一个经过简化的修改版本:

    public static void merge(int[] a, int aux[], int lo, int mid, int hi) {
        int i = lo;
        int j = mid + 1;

        for (int k = lo; k <= mid; k++)  // save a[lo..mid] to aux
            aux[k] = a[k];

        for (int k = lo; k <= hi; k++) {
            if (i > mid)
                a[k] = a[j++];
            else if (j > hi)
                a[k] = aux[i++];
            else if (less(a[j], aux[i]))
                a[k] = a[j++];
            else
                a[k] = aux[i++];
        }
    }

请注意,将mid 视为正确切片的开头并将hi 视为切片末尾之后的索引会更不容易出错。 sort() 循环会更简单,没有棘手的 +/-1 调整。顺便说一句,您的版本中的内部循环测试偏离了一个,尽管除了效率低下没有任何后果。应该是:

for (int lo = 0; lo < N - sz - 1; lo += sz + sz)

这是一个包含/排除切片和组合测试的进一步简化实现:

    public static void sort(int[] a) {
        int N = a.length;
        int[] aux = new int[N];
        for (int sz = 1; sz < N; sz = sz + sz)
            for (int lo = 0; lo < N - sz; lo += sz + sz)
                merge(a, aux, lo, lo + sz, Math.min(lo + sz + sz, N));
    }

    public static void merge(int[] a, int aux[], int lo, int mid, int hi) {
        for (int i = lo; i < mid; i++) { // save a[lo..mid[ to aux
            aux[i] = a[i];
        }
        for (int i = lo, j = mid, k = lo; i < mid; k++) {
            if (j < hi && less(a[j], aux[i]))
                a[k] = a[j++];
            else
                a[k] = aux[i++];
        }
    }

这个版本非常简单,但在大型阵列上仍然不是很有效,因为每次传递都经过整个阵列,破坏了处理器缓存方案。使用一堆大小不断增加的排序子数组以增量方式执行自底向上合并会更有效。

【讨论】:

  • 对于 N 不是 2 的幂,如果没有我的回答中提到的修复,我看不到最后 N-sz 元素的排序位置。
  • @rcgldr:恐怕我不明白为什么需要您的修复:在内部循环中,如果 lo >= N - sz,则子数组 a[lo..N[ 已经排序因为它的大小是 sz,这是在之前的 pass 中排序的。在您的版本中,mergemid == hi 一起调用,这会导致左侧子数组的 2 个副本没有任何更改。
  • 我没有考虑到 OP 合并排序代码总是将数据复制到临时数组,然后合并回来。因此,正如您所评论的,如果在合并过程的末尾有一个子数组 a[lo..N] ,则不需要处理。在每次通过(或递归级别)改变方向的优化合并排序的情况下,需要将 a[lo...N] 或 aux[lo...N] 复制到另一个数组。我的回答的第三部分显示了这种情况。
  • @rcgldr:你从这个优化中得到了多少改进?在大型阵列上,我预计至少 30% 来自增量合并(深度优先),而不是 OP 的广度优先方法。
  • 为避免不必要的副本,自上而下可以根据递归级别更改合并方向,每次通过时自下而上。我用 n = 2^24 = 16777216 伪随机 64 位无符号整数进行了测试。自上而下比自下而上慢约 2%。请注意,自上而下推送和弹出大约 2·n 对索引,因此您可以将自上而下的时间复杂度视为 O(n·(1+log2(n)),自下而上视为 O(n·log2(n)) .对于n = 2^24,也就是O(25·n),与O(24·n)相比,相差4%左右,但由于栈在L1缓存中,开销较小。大多数稳定排序的库使用混合自底向上合并排序+插入排序。
【解决方案2】:

通过此更改,它可以在我的系统上运行。

            else if(less(aux[j], aux[i]))
                a[k] = aux[j++];             // fix  (aux)
            else
                a[k] = aux[i++];             // fix  (aux)

如果合并排序通过改变每次遍历的合并方向来避免复制步骤,如果在合并遍历结束时剩下一个运行,则需要复制它。这个答案的第 3 部分有一个例子。


当我使用带有随机值的较大数组(如 800 万个整数)进行测试时,less(...) 的使用会间歇性地使我的系统上的运行时间加倍。将 if(less(aux[j], aux[i])) 更改为 if(aux[j]


更高效的合并排序的示例代码,它避免了复制,除非有奇数次通过。这可以通过首先计算传递次数来避免,如果传递次数是奇数,则就地交换。这可以通过在初始传递中对 32 或 64 个元素的组使用插入排序扩展到更大的子组。

    public static void sort(int[] a) {
        int n = a.length;
        if(n < 2)
            return;
        int[] dst = new int[n];
        int[] src = a;
        int[] tmp;
        for(int sz = 1; sz < n; sz = sz+sz){
            int lo;
            int md;
            int hi = 0;
            while(hi < n){
                lo = hi;
                md = lo+sz;
                if(md >= n){            // if single run remaining, copy it
                    System.arraycopy(src, lo, dst, lo, n-lo);
                    break;
                }
                hi = md+sz;
                if(hi > n)
                    hi = n;
                merge(src, dst, lo, md, hi);
            }
            tmp = src;                  // swap references
            src = dst;                  //  to change direction of merge
            dst = tmp;
        }
        if(src != a)                    // copy back to a if needed
            System.arraycopy(src, 0, a, 0, n);
    }

    public static void merge(int[] src, int[] dst, int lo, int md, int hi) {
        int i = lo;
        int j = md;
        int k = lo;
        while(true){
            if(src[j]< src[i]){
                dst[k++] = src[j++];
                if(j < hi)
                    continue;
                System.arraycopy(src, i, dst, k, md-i);
                return;
            } else {
                dst[k++] = src[i++];
                if(i < md)
                    continue;
                System.arraycopy(src, j, dst, k, hi-j);
                return;
            }
        }
    }

【讨论】:

  • less 不被间歇性地使用,它被系统地用于所有元素的比较,这是低效的,但可能是为了使算法具有可替换排序顺序的通用性.
  • @chqrlieforyellowblockquotes - 问题是 800 万个整数的运行时间有时约为 1.2 秒,有时为 2.4 秒。如果我删除less,运行时间保持一致,大约为 1.2。秒。
  • 这可能是 JIT 编译器的副作用。鉴于 sort()merge() 方法对 int 数组进行操作的特殊性,使用不可替换的比较方法没有多大意义,因此确实建议去掉它。
【解决方案3】:

你可以试试这段代码:

import java.util.Arrays;

public class MergeSort
{
   public static void merge(double[] a, 
                            int iLeft, int iMiddle, int iRight, 
                            double[] tmp)
   {
      int i, j, k;

      i = iLeft;
      j = iMiddle;
      k = iLeft;

      while ( i < iMiddle || j < iRight )
      {
         if ( i < iMiddle && j < iRight )
         {  // Both array have elements
            if ( a[i] < a[j] )
               tmp[k++] = a[i++];
            else
               tmp[k++] = a[j++];
         }
         else if ( i == iMiddle )
            tmp[k++] = a[j++];     // a is empty
         else if ( j == iRight )
            tmp[k++] = a[i++];     // b is empty
      }

      /* =================================
         Copy tmp[] back to a[]
         ================================= */
      for ( i = iLeft; i < iRight; i++ )
         a[i] = tmp[i];
   }

   public static void sort(double[] a, double[] tmp)
   {
      int width;

      for ( width = 1; width < a.length; width = 2*width )
      {
         // Combine sections of array a of width "width"
         int i;

         for ( i = 0; i < a.length; i = i + 2*width )
         {
            int left, middle, right;

        left = i;
        middle = i + width;
        right  = i + 2*width;

            merge( a, left, middle, right, tmp );

         }

         System.out.println("After 1 iter: " + Arrays.toString(a) );
      }
   }
}

【讨论】:

  • 我不能,我需要遵循那个 API...我真的需要找到我的错误。
猜你喜欢
  • 2015-09-10
  • 2016-12-13
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2016-03-07
  • 2020-05-25
  • 2013-05-03
相关资源
最近更新 更多