【问题标题】:Parallel Merge Sort in ScalaScala中的并行合并排序
【发布时间】:2015-10-25 08:35:41
【问题描述】:

我一直在尝试在 Scala 中实现并行合并排序。但是在 8 核的情况下,使用.sorted 仍然快两倍左右。

编辑:

我重写了大部分代码以最小化对象创建。现在它的运行速度与.sorted 一样快

1.2M 整数的输入文件:

  • 1.333580 秒(我的实现)
  • 1.439293 秒 (.sorted)

我应该如何并行化这个?

新实施

object Mergesort extends App
{

//=====================================================================================================================
// UTILITY
  implicit object comp extends Ordering[Any] {
    def compare(a: Any, b: Any) = {
      (a, b) match {
        case (a: Int, b: Int)       => a compare b
        case (a: String, b: String) => a compare b
        case _                      => 0
      }
    }
  }

//=====================================================================================================================
// MERGESORT

  val THRESHOLD = 30

  def inssort[A](a: Array[A], left: Int, right: Int): Array[A] = {
    for (i <- (left+1) until right) {
      var j = i
      val item = a(j)
      while (j > left && comp.lt(item,a(j-1))) {
        a(j) = a(j-1)
        j -= 1
      }
      a(j) = item
    }
    a
  }

  def mergesort_merge[A](a: Array[A], temp: Array[A], left: Int, right: Int, mid: Int) : Array[A] = {
    var i = left
    var j = right
    while (i < mid) { temp(i) = a(i);   i+=1;       }
    while (j > mid) { temp(i) = a(j-1); i+=1; j-=1; }

    i = left
    j = right-1
    var k = left
    while (k < right) {
      if (comp.lt(temp(i), temp(j))) { a(k) = temp(i); i+=1; k+=1; }
      else                           { a(k) = temp(j); j-=1; k+=1; }
    }
    a
  }

  def mergesort_split[A](a: Array[A], temp: Array[A], left: Int, right: Int): Array[A] = {
    if (right-left == 1) a

    if ((right-left) > THRESHOLD) {
      val mid = (left+right)/2
      mergesort_split(a, temp, left, mid)
      mergesort_split(a, temp, mid, right)
      mergesort_merge(a, temp, left, right, mid)
    }
    else
      inssort(a, left, right)
  }

  def mergesort[A: ClassTag](a: Array[A]): Array[A] = {
    val temp = new Array[A](a.size)
    mergesort_split(a, temp, 0, a.size)
  }

以前的实施

1.2M 整数的输入文件:

  • 4.269937 秒(我的实现)
  • 1.831767 秒 (.sorted)

有哪些技巧可以让它更快更干净?

object Mergesort extends App
{

//=====================================================================================================================
// UTILITY

  val StartNano = System.nanoTime
  def dbg(msg: String) = println("%05d DBG ".format(((System.nanoTime - StartNano)/1e6).toInt) + msg)
  def time[T](work: =>T) = {
    val start = System.nanoTime
    val res = work
    println("%f seconds".format((System.nanoTime - start)/1e9))
    res
  }

  implicit object comp extends Ordering[Any] {
    def compare(a: Any, b: Any) = {
      (a, b) match {
        case (a: Int, b: Int)       => a compare b
        case (a: String, b: String) => a compare b
        case _                      => 0
      }
    }
  }

//=====================================================================================================================
// MERGESORT

  def merge[A](left: List[A], right: List[A]): Stream[A] = (left, right) match {
    case (x :: xs, y :: ys) if comp.lteq(x, y) => x #:: merge(xs, right)
    case (x :: xs, y :: ys) => y #:: merge(left, ys)
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }

  def sort[A](input: List[A], length: Int): List[A] = {
    if (length < 100) return input.sortWith(comp.lt)
    input match {
      case Nil | List(_) => input
      case _ =>
        val middle = length / 2
        val (left, right) = input splitAt middle
        merge(sort(left, middle), sort(right, middle + length%2)).toList
    }
  }

  def msort[A](input: List[A]): List[A] = sort(input, input.length)

//=====================================================================================================================
// PARALLELIZATION

  //val cores = Runtime.getRuntime.availableProcessors
  //dbg("Detected %d cores.".format(cores))
  //lazy implicit val ec = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(cores))

  def futuremerge[A](fa: Future[List[A]], fb: Future[List[A]])(implicit order: Ordering[A], ec: ExecutionContext) =
  {
    for {
      a <- fa
      b <- fb
    } yield merge(a, b).toList
  }

  def parallel_msort[A](input: List[A], length: Int)(implicit order: Ordering[A]): Future[List[A]] = {
    val middle = length / 2
    val (left, right) = input splitAt middle

    if(length > 500) {
      val fl = parallel_msort(left, middle)
      val fr = parallel_msort(right, middle + length%2)
      futuremerge(fl, fr)
    }
    else {
      Future(msort(input))
    }
  }

//=====================================================================================================================
// MAIN

  val results = time({
    val src = Source.fromFile("in.txt").getLines
    val header = src.next.split(" ").toVector
    val lines = if (header(0) == "i") src.map(_.toInt).toList else src.toList
    val f = parallel_msort(lines, lines.length)
    Await.result(f, concurrent.duration.Duration.Inf)
  })

  println("Sorted as comparison...")
  val sorted_src = Source.fromFile(input_folder+"in.txt").getLines
  sorted_src.next
  time(sorted_src.toList.sorted)

  val writer = new PrintWriter("out.txt", "UTF-8")
  try writer.print(results.mkString("\n"))
  finally writer.close
}

【问题讨论】:

  • 您的数据集有多大?你执行了多少次?由于您是性能测试,您是否进行了多次运行,以便 JVM 有时间优化和预热代码?您是否尝试过使用scala.testing.Benchmark 运行?
  • 我主要使用包含几十万行到几百万行的输入文件进行测试。 (上面运行的示例是 1.2M)我现在只运行单次运行,因为我的实现和 .sorted 运行时间之间仍然存在很大差异,多次运行不会有太大变化。
  • 如果您在 JVM 中运行,则需要多次运行才能克服线程启动时间。
  • 查看此问题和回复:stackoverflow.com/q/504103/7507
  • .sorted 是否并行运行? .sorted 是否有可能绕过JVM?当数据适合每个核心的缓存时,似乎并行排序会很有帮助。一旦超出此范围,所有内核都将竞争相同的内存总线。

标签: multithreading scala sorting mergesort


【解决方案1】:

我的回答可能会有点长,但我希望它对你和我都有用。

所以,第一个问题是:“scala 是如何对 List 进行排序的?”让我们看看 scala repo 中的代码!

  def sorted[B >: A](implicit ord: Ordering[B]): Repr = {
    val len = this.length
    val b = newBuilder
    if (len == 1) b ++= this
    else if (len > 1) {
      b.sizeHint(len)
      val arr = new Array[AnyRef](len)  // Previously used ArraySeq for more compact but slower code
      var i = 0
      for (x <- this) {
        arr(i) = x.asInstanceOf[AnyRef]
        i += 1
      }
      java.util.Arrays.sort(arr, ord.asInstanceOf[Ordering[Object]])
      i = 0
      while (i < arr.length) {
        b += arr(i).asInstanceOf[A]
        i += 1
      }
    }
    b.result()
  }

那么这里到底发生了什么?长话短说:用java。其他一切都只是尺寸调整和铸造。基本上这是定义它的行:

java.util.Arrays.sort(arr, ord.asInstanceOf[Ordering[Object]])

让我们更深入地了解 JDK 源代码:

public static <T> void sort(T[] a, Comparator<? super T> c) {
    if (c == null) {
        sort(a);
    } else {
        if (LegacyMergeSort.userRequested)
            legacyMergeSort(a, c);
        else
            TimSort.sort(a, 0, a.length, c, null, 0, 0);
    }
}

legacyMergeSort 只不过是合并排序算法的单线程实现。

下一个问题是:“什么是 TimSort.sort,我们什么时候使用它?”

据我所知,此属性的默认值为 false,这导致我们使用 TimSort.sort 算法。描述可以在here找到。为什么更好?根据 JDK 源中的 cmets 进行归并排序的比较少。

而且你应该知道它都是单线程的,所以这里没有并行化。

第三个问题,“你的代码”

  1. 您创建了太多对象。在性能方面,突变(可悲)是你的朋友。
  2. 过早的优化是万恶之源——Donald Knuth。在进行任何优化(如并行)之前,请尝试实现单线程版本并比较结果。
  3. 使用 JMH 之类的东西来测试代码的性能。
  4. 如果您想获得最佳性能,您可能不应该使用 Stream 类,因为它会进行额外的缓存。

我故意没有给你回答 “可以在这里找到 scala 中的超快速合并排序”,而只是给你一些适用于你的代码和编码实践的提示。

希望对你有帮助。

【讨论】:

猜你喜欢
  • 1970-01-01
  • 2014-06-18
  • 1970-01-01
  • 1970-01-01
  • 2012-11-28
  • 2012-01-16
  • 1970-01-01
  • 1970-01-01
  • 2013-08-04
相关资源
最近更新 更多