【问题标题】:Tail recursive Quicksort continuation-style in ScalaScala中的尾递归快速排序延续样式
【发布时间】:2017-01-26 19:51:25
【问题描述】:

我正在尝试在 Scala 中编写一个尾递归快速排序,它通过建立一个延续来工作,而不使用蹦床。到目前为止,我有以下内容:

object QuickSort {

  def sort[A: Ordering](toSort: Seq[A]): Seq[A] = {
    val ordering = implicitly[Ordering[A]]
    import ordering._

    @scala.annotation.tailrec
    def step(list: Seq[A], conts: List[Seq[A] => Seq[A]]): Seq[A] = list match {
      case s if s.length <= 1 => conts.foldLeft(s) { case (acc, next) => next(acc) }
      case Seq(h, tail @ _*) => {
        val (less, greater) = tail.partition(_ < h)
        step(less, { sortedLess: Seq[A] =>
            /*
            Can't use 

            step(greater, sortedGreater => (sortedLess :+ h) ++ sortedGreater)

            and keep the tailrec annotation
           */
          (sortedLess :+ h) ++ sort(greater)
        } +: conts)
      }
    }

    step(toSort, Nil)
  }

}

Click for ScalaFiddle

在我的计算机上,上述实现适用于至少 4000000 个元素的随机序列,但我对此表示怀疑。具体来说,我想知道:

  1. 它是堆栈安全的吗?我们可以通过查看代码来判断吗?它使用@tailrec 编译,但对sort(greater) 的调用似乎有点可疑。
  2. 如果 (1) 的答案是“否”,是否可以在 Scala 中以 CPS 样式编写尾递归快速排序,即不使用蹦床?怎么样?

为了清楚起见,我查看了this related question,它讨论了如何使用蹦床(我知道如何使用)或您自己的显式堆栈来实现尾递归快速排序,但我特别想知道是否和如何以不同的方式完成。

【问题讨论】:

    标签: scala sorting recursion tail-recursion continuations


    【解决方案1】:

    我决定使用 JVisualVM 来查看我在问题中的实现的调用树,并发现由于 ++ step(greater) 调用,它正在占用堆栈。我认为很难达到堆栈溢出的地步,因为列表每次都被分成两半,较小的一半以尾递归、堆栈安全的方式进行尾递归排序。

    经过一番思考,我想出了以下修改后的解决方案(试试看here

    object QuickSort {
    
      def sort[A: Ordering](toSort: Seq[A]): Seq[A] = {
        val ordering = implicitly[Ordering[A]]
        import ordering._
    
        // Aliasing allows us to be tail-recursive
        def step2(list: Seq[A], conts: Vector[Seq[A] => Seq[A]]): Seq[A] = step(list, conts)
    
        @scala.annotation.tailrec
        def step(list: Seq[A], conts: Vector[Seq[A] => Seq[A]]): Seq[A] = list match {
          case s if s.length <= 1 => conts.foldLeft(s) { case (acc, next) => next(acc) }
          case Seq(h, tail @ _*) => {
            val (less, greater) = tail.partition(_ < h)
            val nextConts: Vector[Seq[A] => Seq[A]] =
              { sortedLess: Seq[A] =>
                sortedLess :+ h
              } +: { appendedLess: Seq[A] =>
                step2(greater, Vector({ sortedGreater => appendedLess ++ sortedGreater }))
              } +: conts
            step(less, nextConts)
          }
        }
        step(toSort, Vector.empty)
      }
    
    }
    

    主要区别是:

    • step 使用step2 别名来保持@tailrec 注释的有效性。
    • 我们没有在延续中调用step(greater) 对较小的分区进行排序,而是在conts 累加器中添加另一个延续,将已排序的较少分区附加到已排序的较大分区。我想你可能会说这个累加器只是堆上的一个堆栈..

    有趣的是,这个解决方案非常快,击败了linked question 中的 Scalaz 蹦床解决方案。与上面的半栈解决方案相比,它在对 100 万个元素进行排序时慢了大约 30 ns,但在误差范围内。

    [info] Benchmark                             (sortLength)  Mode  Cnt     Score    Error  Units
    [info] SortBenchmarks.sort                            100  avgt   30     0.034 ±  0.001  ms/op
    [info] SortBenchmarks.sort                          10000  avgt   30     6.258 ±  0.072  ms/op
    [info] SortBenchmarks.sort                        1000000  avgt   30  1016.849 ± 23.572  ms/op
    [info] SortBenchmarks.scalazSort                      100  avgt   30     0.070 ±  0.001  ms/op
    [info] SortBenchmarks.scalazSort                    10000  avgt   30    10.426 ±  0.092  ms/op
    [info] SortBenchmarks.scalazSort                  1000000  avgt   30  1635.693 ± 68.068  ms/op
    

    【讨论】:

      【解决方案2】:

      不,您的代码不是堆栈安全的。 sort 调用 stepstep 在很大程度上再次调用 sort,因此它不是堆栈安全的。

      要做cps,让我们从正常形式开始:

      def sort(list: Seq[A]): Seq[A] = list match {
        case s if s.length <= 1 => s
        case Seq(h, tail @ _*) => {
          val (less, greater) = tail.partition(_ < h)
          val l = sort(less)
          val g = sort(greater)
          (l :+ Seq(h)) ++ g
        }
      }
      

      然后翻译成cps,很直接:

      def sort(list: Seq[A], cont: Seq[A] => Unit): Unit = list match {
        case s if s.length <= 1 => cont(s)
        case Seq(h, tail @ _*) => {
          val (less, greater) = tail.partition(_ < h)
          sort(less, { l =>
            sort(greater, { g => 
              cont((l :+ Seq(h)) ++ g)
            })
          })
        }
      }
      

      注意:

      • CPS 函数总是返回Unit
      • 继续返回Unit
      • 每个递归调用都变成了对 self 的调用,并在 continue 中包含了保持语句。
      • 返回成为继续调用

      最后,将其包装成正常形式:

      def quicksort(list: Seq[A]): Seq[A] = {
        var result
        sort(list, { r => result = r })
        result
      }
      

      注意:CPS 转换使每个函数都进行尾调用(NOT tail-rec),因为 scala 不支持尾调用优化,因此您需要手动进行尾调用优化:

      trait TCF[T] {
        def result: Option[T]
        def apply(): TCF[T]
      }
      private def tco[T](f: => TCF[T]): TCF[T] = new TCF[T] {
        def result = None
        def apply() = f
      }
      
      def quicksort[A: Ordering](list: Seq[A]): Seq[A] = {
        case class Result(r: Seq[A]) extends Exception
        Iterator.iterate(sort(list, { r: Seq[A] =>
          new TCF[Seq[A]] {
            def result = Some(r)
            def apply() = throw new RuntimeException("unreachable")
          }
        }))(c => c()).dropWhile(_.result == None).next().result.get
      }
      
      private def sort[A: Ordering](list: Seq[A], cont: Seq[A] => TCF[Seq[A]]): TCF[Seq[A]] = {
        val ordering = implicitly[Ordering[A]]
        import ordering._
        list match {
          case s if s.length <= 1 => tco(cont(s))
          case Seq(h, tail@_*) => {
            val (less, greater) = tail.partition(_ < h)
            tco(sort(less, { l: Seq[A] =>
              tco(sort(greater, { g: Seq[A] =>
                tco(cont((l :+ h) ++ g))
              }))
            }))
          }
        }
      }
      

      试试here

      【讨论】:

      • 感谢您的回答。经过几次修复后,我设法让它编译(请参阅scalafiddle.io/sf/OzInX1U/2),但不幸的是,注意到它既不是堆栈安全的(在我的计算机上,当给出长度> 3000 的 Seq 时它会死掉;在 ScalaJS 中,这可能是依赖于浏览器),也不是尾递归(排序不在尾位置;我在解决方案中遇到的问题相同)。另外,如果可能的话,我希望解决方案是完全不可变的(引用和数据结构)。
      • @lloydmeta 我不确定 CPS 是否确保堆栈安全,但如果您需要堆栈安全,您可以编织出延续,它使用与蹦床类似的概念。
      • @lloydmeta CPS 转换使每个函数都进行尾调用(不是 tail-rec),因为 scala 不支持尾调用优化,所以您需要手动进行尾调用优化
      • 感谢您更新代码!它现在可以处理比以前更多的元素。顺便说一句,它看起来像您的原始代码堆栈溢出,因为它在遍历树以构建延续时正在执行非尾递归调用;您的新代码通过蹦床使调用“懒惰”来解决这个问题。不过你是对的,CPS 本身并不能确保 Scala 中的堆栈安全,因为如果每一层都简单地调用前一层(续),将最后一个值传递给延续可能会导致堆栈溢出。
      • (续)我已经更新了我的问题以通过累积延续并在我们到达终点时折叠它们来解决这个问题;见scalafiddle.io/sf/xh0CMpu/1
      【解决方案3】:
      1. 您的代码是尾递归的,因此应该是堆栈安全的。对sort(greater) 的调用停在延续中,它存在于堆而不是堆栈上。给定一个足够大的错误形状问题,您可能会破坏堆,但这比破坏堆栈要多得多。

      【讨论】:

      • 啊,是的,我知道我正在用堆栈换堆,并且考虑到一个足够大的问题,这个 CPS 技巧(以及相关的 trampoline/free monad 变体)会毁掉我的堆。我直观地猜测,sort 调用在堆上是安全的这一事实是有道理的;但我对确切原因的理解有点模糊和手波。
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2017-05-25
      • 1970-01-01
      • 1970-01-01
      • 2020-08-14
      • 1970-01-01
      • 2017-03-25
      • 2021-06-26
      相关资源
      最近更新 更多