【问题标题】:Looking for faster way to implement logSumExp across multidimensional array寻找跨多维数组实现 logSumExp 的更快方法
【发布时间】:2020-07-18 03:42:32
【问题描述】:

我正在编写的一些 R 代码中有一行非常慢。它使用 apply 命令将 logSumExp 应用于 4 维数组。我想知道有没有办法加快速度!

Reprex:(这可能需要 10 秒或更长时间才能运行)

library(microbenchmark)
library(matrixStats)

array4d <- array( runif(5*500*50*5 ,-1,0),
                  dim = c(5, 500, 50, 5) )
microbenchmark(
    result <- apply(array4d, c(1,2,3), logSumExp)
)

任何建议表示赞赏!

【问题讨论】:

    标签: r multidimensional-array


    【解决方案1】:

    rowSums 是 apply 的一个不太通用的版本,它在加法时针对速度进行了优化,因此可以用来加快计算速度。如果保持NANaN 之间的计算差异很重要,请注意帮助文件?rowSums 中的警告。

    library(microbenchmark)
    library(matrixStats)
    
    array4d <- array( runif(5*500*50*5 ,-1,0),
                      dim = c(5, 500, 50, 5) )
    microbenchmark(
      result <- apply(array4d, c(1,2,3), logSumExp),
      result2 <- log(rowSums(exp(array4d), dims=3))
    )
    
    
    # Unit: milliseconds
    #                                            expr      min       lq      mean    median        uq      max neval
    # result <- apply(array4d, c(1, 2, 3), logSumExp) 249.4757 274.8227 305.24680 297.30245 328.90610 405.5038   100
    # result2 <- log(rowSums(exp(array4d), dims = 3))  31.8783  32.7493  35.20605  33.01965  33.45205 133.3257   100
    
    all.equal(result, result2)
    
    #TRUE
    

    这使我的计算机速度提高了 9 倍

    【讨论】:

    • 确实,这在我的机器上实现了 37 倍的加速!但是,我担心精确度,它为 all.equal() 提供 TRUE,但对于 identical() 则不然。我试图实现此答案的变体,但 rowLogSumExps 的行为方式与 rowMeans 不同,因此我无法使其工作:stackoverflow.com/a/18633390/2498193
    • @user2498193 通常,由于机器精度和舍入,您不会期望两种不同的计算方法会给出相同的答案。 range(result - result2) 给出大约 10^-16 的数字,所以我建议差异可以忽略不计
    • Obligatory link 为什么数字不完全相同
    • 我有另一个问题,如果您有时间 (stackoverflow.com/questions/62894184/…),我认为值得跟进的问题 - 但如果没有,谢谢您迄今为止的帮助!
    • 唉!这又回来咬我了-极端数据会产生无穷大,而原始数据不会?
    【解决方案2】:

    @Miff 的另一个很好的解决方案是导致我的代码在某些数据集上崩溃,因为正在生成无穷大,我最终发现这是由于下溢问题,可以通过使用“logSumExp 技巧”来避免:https://www.xarg.org/2016/06/the-log-sum-exp-trick-in-machine-learning/

    从@Miff 的代码和R apply() 函数中汲取灵感,我创建了一个新函数,可以在避免下溢问题的同时提供更快的计算。然而,没有@Miff 的解决方案那么快。发帖以防对他人有帮助

    apply_logSumExp <- function (X) {
        MARGIN <- c(1, 2, 3) # fixing the margins as have not tested other dims
        dl <- length(dim(X)) # get length of dim
        d <- dim(X) # get dim
        dn <- dimnames(X) # get dimnames
        ds <- seq_len(dl) # makes sequences of length of dims
        d.call <- d[-MARGIN]    # gets index of dim not included in MARGIN
        d.ans <- d[MARGIN]  # define dim for answer array
        s.call <- ds[-MARGIN] # used to define permute
        s.ans <- ds[MARGIN]     # used to define permute
        d2 <- prod(d.ans)   # length of results object
        
        newX <- aperm(X, c(s.call, s.ans)) # permute X such that dims omitted from calc are first dim
        dim(newX) <- c(prod(d.call), d2) # voodoo. Preserves ommitted dim dimension but collapses the rest into 1
        
        maxes <- colMaxs(newX)
        ans <- maxes + log(colSums(exp( sweep(newX, 2, maxes, "-"))) )
        ans <- array(ans, d.ans)
        
        return(ans)
    }
    
     > microbenchmark(
    +     res1 <- apply(array4d, c(1,2,3), logSumExp),
    +     res2 <- log(rowSums(exp(array4d), dims=3)),
    +     res3 <- apply_logSumExp(array4d)
    + )
    Unit: milliseconds
                                              expr        min         lq       mean    median        uq       max
     res1 <- apply(array4d, c(1, 2, 3), logSumExp) 176.286670 213.882443 247.420334 236.44593 267.81127 486.41072
      res2 <- log(rowSums(exp(array4d), dims = 3))   4.664907   5.821601   7.588448   5.97765   7.47814  30.58002
                  res3 <- apply_logSumExp(array4d)  12.119875  14.673011  19.635265  15.20385  18.30471  90.59859
     neval cld
       100   c
       100 a  
       100  b 
    

    【讨论】:

      猜你喜欢
      • 2019-05-31
      • 2020-06-19
      • 2010-12-17
      • 1970-01-01
      • 2014-03-18
      • 2021-10-17
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多