【Title】: Scala recursion vs loop: performance and runtime considerations
【Posted】: 2013-02-28 14:56:25
【Question】:

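I wrote a simple test bench to measure the performance of three factorial implementations: loop-based, non-tail-recursive, and tail-recursive.

To my surprise, the worst performer was the loop («while» was expected to be more efficient, so I provide both), which took almost twice as long as the tail-recursive alternative.

Answer: fixing the loop implementation to avoid the *= operator, which performs worst with BigInt because of its internals, makes the «loops» the fastest, as expected.

Another «voodoo» behaviour I ran into is that the StackOverflow exception is not thrown systematically for the same input in the non-tail-recursive implementation. I can circumvent the StackOverflow by progressively calling the function with larger and larger values... I feel crazy :) Answer: the JVM needs to converge during start-up; after that the behaviour is coherent and systematic.

Here is the code: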

import scala.annotation.tailrec // required by the @tailrec annotations below

object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    n match {
      case _ if n == 1 => return 1
      case _ => return n * calculateByRecursion(n-1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    var i = 1
    while (i <= n) {
      accumulator = i * accumulator
      i += 1
    }
    accumulator
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(n: Int, acc: Out): Out = n match {
      case _ if n == 1 => acc
      case _ => fac(n-1, n * acc)
    }

    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(i: Int, acc: Out): Out = n match {
      case _ if i == n => n * acc
      case _ => fac(i+1, i * acc)
    }

    fac(1, 1)
  }

  def comparePerformance(n: Int) {
    def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = false) =
      showOutput match {
        case true => printf("%s returned %s in %d ms\n", msg, data._2.toString, data._1)
        case false => printf("%s in %d ms\n", msg, data._1)
    }
    def measure[A](f:()=>A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }
    showOutput ("By for loop", measure(()=>calculateByForLoop(n)))
    showOutput ("By while loop", measure(()=>calculateByWhileLoop(n)))
    showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n)))
    showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n)))
    showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n)))
  }
}

Here is some output from the sbt console (before the «while» implementation):

scala> example.Factorial.comparePerformance(10000)
By loop in 3 ns
By non-tail recursion in >>>>> StackOverflow!!!!!… see later!!!
........

scala> example.Factorial.comparePerformance(1000)
By loop in 3 ms
By non-tail recursion in 1 ms
By tail recursion in 4 ms

scala> example.Factorial.comparePerformance(5000)
By loop in 105 ms
By non-tail recursion in 27 ms
By tail recursion in 34 ms

scala> example.Factorial.comparePerformance(10000)
By loop in 236 ms
By non-tail recursion in 106 ms     >>>> Now works!!!
By tail recursion in 127 ms

scala> example.Factorial.comparePerformance(20000)
By loop in 977 ms
By non-tail recursion in 495 ms
By tail recursion in 564 ms

scala> example.Factorial.comparePerformance(30000)
By loop in 2285 ms
By non-tail recursion in 1183 ms
By tail recursion in 1281 ms

Here is some output from the sbt console (after the «while» implementation):

scala> example.Factorial.comparePerformance(10000)
By for loop in 252 ms
By while loop in 246 ms
By non-tail recursion in 130 ms
By tail recursion in 136 ns

scala> example.Factorial.comparePerformance(20000)
By for loop in 984 ms
By while loop in 1091 ms
By non-tail recursion in 508 ms
By tail recursion in 560 ms

Here is some output from the sbt console (after the «upward» tail recursion implementation). The world is back to sane:

scala> example.Factorial.comparePerformance(10000)
By for loop in 259 ms
By while loop in 229 ms
By non-tail recursion in 114 ms
By tail recursion in 119 ms
By tail recursion upward in 105 ms

scala> example.Factorial.comparePerformance(20000)
By for loop in 1053 ms
By while loop in 957 ms
By non-tail recursion in 513 ms
By tail recursion in 565 ms
By tail recursion upward in 470 ms

Here is some output from the sbt console after fixing the BigInt multiplication in the «loops». The world is totally sane:

scala> example.Factorial.comparePerformance(20000)
By for loop in 498 ms
By while loop in 502 ms
By non-tail recursion in 521 ms
By tail recursion in 611 ms
By tail recursion upward in 503 ms

The BigInt overhead and a stupid implementation on my part masked the expected behaviour.

PS: in the end I should rename this post «A lesson learnt on BigInts».

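【Comments】:

  • What is the question?
  • "Loops should be avoided" is misleading: in Scala, for loops are usually much slower than the equivalent while loop. As for tail recursion usually being faster than non-tail recursion, I would say it is slower here, because you create a closure before the function starts.
  • Can anyone help me understand this behaviour, pointing to authoritative references?
  • @Score_Under: I don't see any closure in the tail-recursive implementation. Which one are you referring to?
  • Rex addressed the main issues, but you should read up on benchmarking on the JVM, because you are doing it wrong. Benchmarking correctly on the JVM is very, very hard, and it is hard to infer anything useful from the results (what holds for a microbenchmark does not hold in a larger context). One thing to note here is that you are not garbage collecting before starting a test, which is bound to produce large variations.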
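
The benchmarking concern in the last comment can be illustrated with a minimal sketch. It uses only the standard library; `WarmedUpTimer` and `bestOfNanos` are hypothetical names that do not appear in the question's code. The idea is to run the body a few times untimed so the JIT has a chance to compile it, request a garbage collection (which is only a hint to the JVM), and then keep the best of several timed runs. This reduces noise but does not replace a dedicated benchmarking harness.

```scala
object WarmedUpTimer {
  /** Runs `body` `warmups` times untimed, then returns the fastest of `runs`
    * timed executions, in nanoseconds. Names and defaults are illustrative only.
    */
  def bestOfNanos[A](warmups: Int, runs: Int)(body: => A): Long = {
    require(runs > 0, "runs must be positive")

    // Untimed warm-up so that JIT compilation is less likely to be measured.
    var i = 0
    while (i < warmups) { body; i += 1 }

    // A request, not a guarantee: the JVM may ignore it.
    System.gc()

    var best = Long.MaxValue
    i = 0
    while (i < runs) {
      val start = System.nanoTime()
      body
      val elapsed = System.nanoTime() - start
      if (elapsed < best) best = elapsed
      i += 1
    }
    best
  }
}
```

A call could look like `WarmedUpTimer.bestOfNanos(warmups = 5, runs = 10) { Factorial.calculateByWhileLoop(10000) }`, assuming the question's `Factorial` object is on the classpath.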

Tags: performance scala stack-overflow tail-recursion


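【Answer 1】:

For loops are not actually quite loops; they are for comprehensions over a range. If you really want a loop, you need to use while. (Actually, I think the BigInt multiplication here is heavyweight enough that it should not matter. But you would notice if you were multiplying Ints.)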
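
As a rough illustration of that point, the sketch below puts a `for` over a range next to the `foreach` call the compiler rewrites it to, and next to a plain `while`. The object and method names are hypothetical, and a cheap `Long` sum is used instead of a BigInt product so that the looping mechanism is the only work. All three return the same value; any speed difference depends on the Scala and JVM versions and is not asserted here.

```scala
object ForVersusWhile {
  // A for over a Range: sugar for a foreach call that takes a closure.
  def sumWithFor(n: Int): Long = {
    var sum = 0L
    for (i <- 1 to n) sum += i
    sum
  }

  // Roughly what the compiler turns the for above into.
  def sumWithForeach(n: Int): Long = {
    var sum = 0L
    (1 to n).foreach(i => sum += i)
    sum
  }

  // A plain loop: no Range object and no closure.
  def sumWithWhile(n: Int): Long = {
    var sum = 0L
    var i = 1
    while (i <= n) {
      sum += i
      i += 1
    }
    sum
  }
}
```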

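Also, you have confused yourself by using BigInt. The bigger a BigInt is, the slower the multiplication. Your non-tail recursion multiplies counting up, while your tail recursion multiplies counting down, which means the latter has more large numbers to multiply.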
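
One way to see the effect of the multiplication order, without timing anything, is to add up the bit lengths of every intermediate product. This is only a proxy for the work BigInt has to do, and `MultiplicationOrder` with its two methods is a hypothetical helper, not part of the question's code. Counting up produces the partial products 1!, 2!, 3!, ..., while counting down produces n, n(n-1), n(n-1)(n-2), ..., which are never smaller at the same step.

```scala
object MultiplicationOrder {
  // Sum of the bit lengths of all intermediate products when multiplying 1 * 2 * ... * n.
  def upwardBits(n: Int): Long = {
    var acc = BigInt(1)
    var bits = 0L
    var i = 1
    while (i <= n) {
      acc = acc * i
      bits += acc.bitLength
      i += 1
    }
    bits
  }

  // The same measure when multiplying n * (n-1) * ... * 1.
  def downwardBits(n: Int): Long = {
    var acc = BigInt(1)
    var bits = 0L
    var i = n
    while (i >= 1) {
      acc = acc * i
      bits += acc.bitLength
      i -= 1
    }
    bits
  }
}
```

For n = 3 the upward order sees products 1, 2, 6 (6 bits in total) and the downward order sees 3, 6, 6 (8 bits in total); the gap widens as n grows.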

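If you fix both of these issues, you will find that sanity is restored: loops and tail recursion run at the same speed, and both regular recursion and for are slower. (Regular recursion may not be slower if JVM optimization makes it equivalent.)

(Also, the stack overflow fix is probably because the JVM starts inlining, and may either make the call tail-recursive itself, or unroll the loop far enough that you no longer overflow.)

Finally, you get poor results with for and while because you are multiplying with the small number on the right rather than on the left. It turns out that Java's BigInt multiplies faster with the smaller number on the left.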

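【Comments】:

  • I'm confused. Looking at the code, I would say both count down. Does the order somehow change because of tail recursion?
  • I just added a «while» implementation, which performs better than «for»... give me 1' to integrate your pearl!
  • @bluenote10 - Regular recursion computes the deepest value first. As written, the tail recursion does the opposite. Write out a few terms and you will see.
  • What do you think about the lack of determinism in when the StackOverflow exception is thrown? ... I still think some kind of memoization is at work behind the scenes.
  • @LordoftheGoo - I added that to the answer. It is hard to be sure without having the JVM print the assembly during inlining.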
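
The non-determinism discussed in these comments can be observed directly with a small probe. This is a sketch under stated assumptions: `StackDepthProbe` is a hypothetical helper, and the depth it reports depends on the JVM, its stack size setting, and how far JIT compilation has progressed, so no particular number should be expected. Calling it several times in the same session often gives different depths, which is consistent with the explanation that compiled frames differ in size from interpreted ones.

```scala
object StackDepthProbe {
  private var depth = 0

  // Deliberately not tail-recursive: the addition happens after the call
  // returns, so every level of recursion needs its own stack frame.
  private def dive(): Int = {
    depth += 1
    1 + dive()
  }

  /** Recurses until the stack overflows and returns the depth that was reached. */
  def maxDepth(): Int = {
    depth = 0
    try {
      dive()
      ()
    } catch {
      case _: StackOverflowError => () // expected: this is how the probe stops
    }
    depth
  }
}
```

Printing `StackDepthProbe.maxDepth()` a few times in a row in the sbt console is enough to compare the values.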
【Answer 2】:

Scala static methods for factorial(n) (coded with Scala 2.12.x, Java 8):

object Factorial {

  /*
   * For large N, it throws a stack overflow
   */
  def recursive(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      n * recursive(n - 1)
    }
  }

  /*
   * A tail recursive method is compiled to avoid stack overflow
   */
  @scala.annotation.tailrec
  def recursiveTail(n:BigInt, acc:BigInt = 1): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      acc
    } else {
      recursiveTail(n - 1, n * acc)
    }
  }

  /*
   * A while loop
   */
  def loop(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      var acc: BigInt = 1 // BigInt accumulator: a plain Int would silently overflow for n > 12
      var idx = 1
      while(idx <= n) {
        acc = idx * acc
        idx += 1
      }
      acc
    }
  }

}

Specs:

class FactorialSpecs extends SpecHelper {

  private val smallInt = 10
  private val largeInt = 10000

  describe("Factorial.recursive") {
    it("return 1 for 0") {
      assert(Factorial.recursive(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursive(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursive(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursive(smallInt) == 3628800)
    }
    it("throws StackOverflow for large inputs") {
      intercept[java.lang.StackOverflowError] {
        Factorial.recursive(Int.MaxValue)
      }
    }
  }

  describe("Factorial.recursiveTail") {
    it("return 1 for 0") {
      assert(Factorial.recursiveTail(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursiveTail(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursiveTail(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursiveTail(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.recursiveTail(largeInt).isInstanceOf[BigInt])
    }
  }

  describe("Factorial.loop") {
    it("return 1 for 0") {
      assert(Factorial.loop(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.loop(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.loop(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.loop(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.loop(largeInt).isInstanceOf[BigInt])
    }
  }
}

Benchmarks:

import org.scalameter.api._

class BenchmarkFactorials extends Bench.OfflineReport {

  val gen: Gen[Int] = Gen.range("N")(1, 1000, 100) // scalastyle:ignore

  performance of "Factorial" in {
    measure method "loop" in {
      using(gen) in {
        n => Factorial.loop(n)
      }
    }
    measure method "recursive" in {
      using(gen) in {
        n => Factorial.recursive(n)
      }
    }
    measure method "recursiveTail" in {
      using(gen) in {
        n => Factorial.recursiveTail(n)
      }
    }
  }

}

Benchmark results (the loop is faster):

[info] Test group: Factorial.loop
[info] - Factorial.loop.Test-9 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.01 ms, ci = <0.00 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.01 ms, ci = <0.01 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.02 ms, ci = <0.02 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.03 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.03 ms, ci = <0.03 ms, 0.04 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.04 ms, ci = <0.04 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.05 ms, ci = <0.05 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.06 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.07 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursive
[info] - Factorial.recursive.Test-10 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.05 ms, ci = <0.01 ms, 0.09 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.07 ms, ci = <0.00 ms, 0.13 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.09 ms, ci = <0.01 ms, 0.18 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.10 ms, ci = <0.03 ms, 0.17 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.11 ms, ci = <0.08 ms, 0.15 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.13 ms, ci = <0.11 ms, 0.14 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.16 ms, ci = <0.13 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.21 ms, ci = <0.15 ms, 0.27 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursiveTail
[info] - Factorial.recursiveTail.Test-11 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.12 ms, ci = <0.05 ms, 0.20 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.16 ms, ci = <-0.03 ms, 0.34 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.12 ms, ci = <0.09 ms, 0.16 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.17 ms, ci = <0.15 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.23 ms, ci = <0.19 ms, 0.26 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.25 ms, ci = <0.18 ms, 0.32 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.28 ms, ci = <0.21 ms, 0.36 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.32 ms, ci = <0.17 ms, 0.46 ms>, significance = 1.0E-10)

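【Comments】:

【Answer 3】:

I know everyone has already answered this question, but I thought I might add this one optimization: converting the pattern match into a simple if statement can speed up the tail recursion.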

    final object Factorial {
      type Out = BigInt
    
      def calculateByRecursion(n: Int): Out = {
        require(n>0, "n must be positive")
    
        n match {
          case _ if n == 1 => return 1
          case _ => return n * calculateByRecursion(n-1)
        }
      }
    
      def calculateByForLoop(n: Int): Out = {
        require(n>0, "n must be positive")
    
        var accumulator: Out = 1
        for (i <- 1 to n)
          accumulator = i * accumulator
        accumulator
      }
    
      def calculateByWhileLoop(n: Int): Out = {
        require(n>0, "n must be positive")
    
        var acc: Out = 1
        var i = 1
        while (i <= n) {
          acc = i * acc
          i += 1
        }
        acc
      }
    
      def calculateByTailRecursion(n: Int): Out = {
        require(n>0, "n must be positive")
    
        @annotation.tailrec
        def fac(n: Int, acc: Out): Out = if (n==1) acc else fac(n-1, n*acc)
    
        fac(n, 1)
      }
    
      def calculateByTailRecursionUpward(n: Int): Out = {
        require(n>0, "n must be positive")
    
        @annotation.tailrec
        def fac(i: Int, acc: Out): Out = if (i == n) n*acc else fac(i+1, i*acc)
    
        fac(1, 1)
      }
    
      def attempt(f: ()=>Unit): Boolean = {
        try {
            f()
            true
        } catch {
            case _: Throwable =>
                println(" <<<<< Failed...")
                false
        }
      }
    
      def comparePerformance(n: Int) {
        def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = true) =
          showOutput match {
            case true =>
                val res = data._2.toString
                val pref = res.substring(0,5)
                val midd = res.substring((res.length-5)/ 2, (res.length-5)/ 2 + 10)
                val suff = res.substring(res.length-5)
                printf("%s returned %s in %d ms\n", msg, s"$pref...$midd...$suff" , data._1)
            case false => 
                printf("%s in %d ms\n", msg, data._1)
        }
        def measure[A](f:()=>A): (Long, A) = {
          val start = System.currentTimeMillis
          val o = f()
          (System.currentTimeMillis - start, o)
        }
        attempt(() => showOutput ("By for loop", measure(()=>calculateByForLoop(n))))
        attempt(() => showOutput ("By while loop", measure(()=>calculateByWhileLoop(n))))
        attempt(() => showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n))))
        attempt(() => showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n))))
        attempt(() => showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n))))
      }
    }
    

My results:

    scala> Factorial.comparePerformance(20000)
    By for loop returned 18192...5708616582...00000 in 179 ms
    By while loop returned 18192...5708616582...00000 in 159 ms
    By non-tail recursion <<<<< Failed...
    By tail recursion returned 18192...5708616582...00000 in 169 ms
    By tail recursion upward returned 18192...5708616582...00000 in 174 ms
    
    By for loop returned 18192...5708616582...00000 in 212 ms
    By while loop returned 18192...5708616582...00000 in 156 ms
    By non-tail recursion returned 18192...5708616582...00000 in 155 ms
    By tail recursion returned 18192...5708616582...00000 in 166 ms
    By tail recursion upward returned 18192...5708616582...00000 in 137 ms
    
    scala> Factorial.comparePerformance(200000)
    By for loop returned 14202...0169293868...00000 in 17467 ms
    By while loop returned 14202...0169293868...00000 in 17303 ms
    By non-tail recursion <<<<< Failed...
    By tail recursion returned 14202...0169293868...00000 in 18477 ms
    By tail recursion upward returned 14202...0169293868...00000 in 17188 ms
    

【Comments】:
