【问题标题】:Is there any faster alternative to stats:uniroot function in R?有没有比 R 中的 stats:uniroot 函数更快的替代方法?
【发布时间】:2021-12-19 22:35:15
【问题描述】:

我在 data.table 中的一百万行上运行 stats::uniroot 函数。这是一个玩具示例 -

library(data.table)
cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

n <- 50000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

print(system.time(
dt[, c := uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u]
))

在上面的代码中,50,000 行所用的时间接近 8 秒。

有没有比uniroot 函数更快的替代方法可以大大减少这个时间?

【问题讨论】:

    标签: r data.table


    【解决方案1】:

    160 秒 (1e6/5e4 * 8) 对我来说一百万行听起来还不错(尽管您的实际功能可能比您在这里使用的 froot 慢得多?)。这可以简单地并行化,在不同的内核上运行单独的块(参见例如对this question 的回答)。

    你有多需要 extendInt ?如果我制作 uniroot() 函数的黑客版本,只有其核心功能,没有参数测试逻辑等,我可以将速度提高三倍。但是,如果你的目标函数是,你的速度增益将不那么令人印象深刻比您在此处给出的示例要慢得多;如果是这种情况,您应该专注于加速您的目标函数(我尝试通过 Rcpp 在 C++ 中重新编码您的 froot,但在这种情况下它并没有真正帮助 - 该函数非常简单,函数调用开销花费了大部分时间...)

    为了便于基准测试,我只使用了 5000 行:

    n <- 5000
    u <- -log(runif(n))
    a <- 1/2
    b <- 1
    dt = data.table(u = u, a = a, b = b)
    

    最小功能:

    uu <- function(f, lower, upper, tol = 1e-8, maxiter =1000L, ...) {
      f.lower <- f(lower, ...)
      f.upper <- f(upper, ...)
      val <- .External2(stats:::C_zeroin2, function(arg) f(arg, ...),
                        lower, upper, f.lower, f.upper, tol, as.integer(maxiter))
      return(val[1])
    }
    

    检查我们是否得到相同的结果:

    identical(uniroot(froot, u = 3.242, a=0.5, b=1, interval = c(0.01,100))$root,
              uu(froot, u = 3.242, a=0.5, b=1, lower = 0.01, upper = 100))
    ## TRUE
    

    基准测试包;将评估包装在函数中以实现紧凑性

    library(rbenchmark)
    f1 <- function() {
      dt[, c := uniroot(froot_cpp, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u]
    }
    f2 <- function() {
      dt[, c := uu(froot, u=u, a=a, b=b, lower = 0.01, upper = 100), by = u]
    }
    bb <- benchmark(f1(), f2(), 
        columns =c("test", "replications", "elapsed", "relative"))
    

    结果:

      test replications elapsed relative
    1 f1()          100  34.616    3.074
    2 f2()          100  11.261    1.000
    

    【讨论】:

    • 感谢 Ben,最小函数将我的代码性能提高了 20%,这是一个巨大的进步。
    • 太棒了。在这种情况下,您可能会从加速目标函数中获得更多的里程……(如果您有足够的内存与内核一起使用,那么拆分/并行化将为您带来与内核一样多的改进)。
    • 我有 8 个核心,data.table 默认使用 4 个核心。将内核从 4 个增加到 8 个确实将性能提高了 2%,但这并不值得,因为我希望保留一些内核可用于其他任务。使用 functionmcapply 也不会提高性能。我会尽量优化功能。
    • rootSolve::uniroot.all 函数将性能比 stats::uniroot 函数提高了 15%,但准确性受到影响。您的最小功能仍然更快更准确。
    【解决方案2】:

    请注意,所示函数的逆函数可以显式计算为

    f2 <- function(x) (b^a * x / b)^(1/a)
    a <- 1/2
    b <- 1
    all.equal(f(.5), f2(.5))  # f defined below using uniroot
    ## [1] TRUE
    

    但是,假设实际上您有一个更复杂的函数,我们可以使用 Chebyshev 近似来得到它的近似值。请注意,a 和 b 是问题中的常量,因此我们还假设以下情况,即 f 使用在全局环境中设置的常量 a 和 b。下面的代码运行速度比基准问题中的代码快近 100 倍,具有 9 次多项式,并且在 uniroot 给出的答案的 1e-4 范围内。如果您需要更高的准确性,请使用更高的度数。

    library(data.table)
    library(pracma)
    set.seed(123)
    
    cumhaz <- function(t, a, b) b * (t/b)^a
    froot <- function(x, u, a, b) cumhaz(x, a, b) - u
    
    n <- 5000
    u <- -log(runif(n))
    a <- 1/2
    b <- 1
    dt = data.table(u = u, a = a, b = b)
    
    dt2 <- copy(dt)
    f <- function(u) {
      uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root
    }
    
    library(microbenchmark)
    microbenchmark(times = 10,
      orig = dt[, c := uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u],
      cheb = dt2[, c := chebApprox(u, Vectorize(f), min(u), max(u), 9)]
    )
    ## Unit: milliseconds
    ##  expr      min       lq      mean    median       uq      max neval cld
    ##  orig 943.5323 948.9321 961.00361 958.91970 972.6308 982.0060    10   b
    ##  cheb   9.3752   9.7513  10.67386  10.02555  10.3411  16.9475    10  a 
    
    max(abs(dt$c - dt2$c))
    ## [1] 8.081021e-05
    

    【讨论】:

    • 谢谢,格洛腾迪克。不幸的是,变量ab 在我的原始函数中不是常量。有没有办法像在uniroot 函数中那样将它们传递给 chebApprox 函数?
    • 你可能会弄乱Vectorizevectorize.args参数
    • 由于ab 不是常量,我尝试了以下方法,但我在变量c 中得到NaN 值。 vec_f &lt;- Vectorize(f, vectorize.args = c("u", "a", "b")) dt2[, c := chebApprox(x = u, fun = function(x) vec_f(u = x, a, b), a = min(u), b = max(u), n = 9), by = u]。我一定是做错了什么,请您指出来吗?
    • 如果问题不能简化为 1 个或多个 1d 问题,则 chebApprox 不适用。例如,有少量的 a、b 组合,然后 chebApprox 可以分别应用于每个组合。如果您的函数有特殊功能,那么可能还有其他方法,但我们没有这方面的任何信息。
    【解决方案3】:

    对于确切的问题有很好的答案,但有一些关于一般 R 实践的注释。

    在顺序无关紧要时使用

    在 OP 中,我们使用 by = u 以便每一行一次运行一个。这是低效的! data.table 将为您排序,确定分组,并且由于它们是真正的非常随机的数字,因此最终分组与行一样多。

    相反,我们可以使用Map()mapply() 来遍历行,这将提高性能。请注意,尚不清楚 ab 是否真的因行而异 - 如果它们确实是常量,我们可能希望将它们从 data.table 中取出并作为常量传递。

    uniroot2 = function(...) uniroot(...)$root ## helper function
    dt[, c2 := mapply(uniroot2, u, a,b,
                      MoreArgs = list (f = froot,
                                       interval = c(0.01, 10),
                                       extendInt = 'yes'))]
    
    ## for n = 5000
    
    ## # A tibble: 2 x 13
    ##   expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    ##   <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
    ##  1 OP           1.17s   1.17s     0.851     170KB     2.55     1     3      1.17s
    ##  2 no_by      857.2ms 857.2ms     1.17      214KB     3.50     1     3    857.2ms
    ##
    ## Warning message:
    ## Some expressions had a GC in every iteration; so filtering is disabled. 
    

    注意,一旦我们在mapply 中设置了它,使用future.apply::future_mapply() 来并行化我们的调用就很简单了。这比我笔记本电脑上的上述 no_by 示例快 2.5 倍。

    library(future.apply)
    plan(multisession)
    dt[, c3 := future_mapply(uniroot2, u, a,b,
                      MoreArgs = list (f = froot,
                                       interval = c(0.01, 10),
                                       extendInt = 'yes')
                      , future.globals = "cumhaz")] ## see next section for how we could remove this
    

    函数调用需要时间

    在您的示例中,您将两个函数定义为:

    cumhaz <- function(t, a, b) b * (t/b)^a
    froot <- function(x, u, a, b) cumhaz(x, a, b) - u
    

    当性能是一个问题并且简化很简单时,您可能想要简化。

    froot2 = function(x, u, a, b) b * (x / b) ^ a - u
    

    超过一百万个循环,对cumhaz() 的额外调用加起来:

    x = 2.5; u = 1.5; a = 0.5; b = 1 
    bench::mark(froot_rep = for (i in 1:1e6) {froot(x=x, u=u, a=a, b=b)},
                froot2_rep = for (i in 1:1e6) {froot2(x=x, u=u, a=a, b=b)})
    
    ## # A tibble: 2 x 13
    ##   expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    ##   <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
    ## 1 froot_rep    4.74s   4.74s     0.211    13.8KB     3.38     1    16      4.74s
    ## 2 froot2_rep   3.17s   3.17s     0.315    13.8KB     2.84     1     9      3.17s
    ##
    ## Warning message:
    ## Some expressions had a GC in every iteration; so filtering is disabled. 
    

    因为uniroot 将进一步增加调用次数,默认最大迭代次数为 1,000!这意味着cumhaz() 在优化过程中会花费我们 1.5 到 1,500 秒之间的时间。作为@G。格洛腾迪克指出,有时我们实际上可以直接求解并使用直接向量化的方法,而不是依赖unirootoptimize

    【讨论】:

    • 感谢科尔的建议。使用分组进行行迭代肯定是没有意义的。我会试试你的建议。
    • 在我原来的函数中,变量ab 不是常量。我将您的建议应用到我的代码中,结果如下。 1.) 将较小的函数合并在一起,将性能提高了大约 5%,感谢这一点。 2.) 使用mapply 而不是分组by 对性能没有影响。 3.) 使用并行性降低了 5% 的性能。
    • 感谢分享。发现 2 和 3 有点令人惊讶。如果你想扩展你的原始帖子,看看真正的实现是什么会很有趣。至于by,那么你可能想做by = seq_len(nrow(dt))。这无济于事,但效率更高。
    • 使用by = seq_len(nrow(dt)) 将性能提高了大约 2%,这可能是由于减少了订购开销。谢谢!
    猜你喜欢
    • 2017-07-19
    • 2018-12-05
    • 1970-01-01
    • 1970-01-01
    • 2017-03-18
    • 1970-01-01
    • 2018-01-12
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多